Example #1
    def iter_train(self, only_cached=False, _top_epoch=False, ewc_params=None):
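        """Generator over training contexts.

        Yields one context before training (epoch -1), then one per epoch.
        When ewc_params is given, an EWC penalty over previously learned
        tasks is added to the data loss; after setewc() flips self._ewc,
        the next iteration yields the trainable parameters instead, so the
        caller can estimate Fisher information (see the run() driver).
        """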
        epoch = -1
        base_path = util.path_model_trainer(self.ranker, self.vocab, self,
                                            self.dataset)
        context = {
            'epoch': epoch,
            'batch_size': self.config['batch_size'],
            'batches_per_epoch': self.config['batches_per_epoch'],
            'num_microbatches': 1 if self.config['grad_acc_batch'] == 0 else self.config['batch_size'] // self.config['grad_acc_batch'],
            'device': self.device,
            'base_path': base_path,
        }
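        # num_microbatches above: when grad_acc_batch is nonzero, each logical
        # batch of batch_size items is split into microbatches whose gradients
        # accumulate before a single optimizer step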

        files = trainers.misc.PathManager(base_path)
        self.logger.info(f'train path: {base_path}, batches_per_epoch : {context["batches_per_epoch"]}')
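        # total number of training instances per epoch (drives the progress bar)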
        b_count = context['batches_per_epoch'] * context['num_microbatches'] * self.batch_size

        ranker = self.ranker.to(self.device)
        optimizer = self.create_optimizer()

        if f'{epoch}.p' not in files['weights']:
            ranker.save(files['weights'][f'{epoch}.p'])
        if f'{epoch}.p' not in files['optimizer']:
            torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])

        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights'][f'{epoch}.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer'][f'{epoch}.p'],
        })

        #load ranker/optimizer
        if _top_epoch:
            # trainer.pipeline=msmarco_train_bm25_k1-0.82_b-0.68.100_mspairs
            __path = "/".join(base_path.split("/")[:-1]) + "/" + self.config['pipeline']
            self.logger.info(f'loading prev model : {__path}')
            w_path = os.path.join(__path, 'weights', '-2.p')
            oppt_path = os.path.join(__path, 'optimizer', '-2.p')
            ranker.load(w_path)
            optimizer.load_state_dict(torch.load(oppt_path))

            if w_path != files['weights']['-2.p']:
                copyfile(w_path, files['weights']['-2.p'])
                copyfile(oppt_path, files['optimizer']['-2.p'])

            context.update({
                'ranker': lambda: ranker,
                'ranker_path': files['weights']['-2.p'],
                'optimizer': lambda: optimizer,
                'optimizer_path': files['optimizer']['-2.p'],
            })

        yield context  # before training

        # one iteration per epoch: replay cached epochs from disk when
        # available; otherwise train, checkpoint, and yield fresh results
        while True:
            context = dict(context)

            epoch = context['epoch'] = context['epoch'] + 1
            if epoch in files['complete.tsv']:
                context.update({
                    'loss': files['loss.txt'][epoch],
                    'data_loss': files['data_loss.txt'][epoch],
                    'losses': {},
                    'acc': files['acc.tsv'][epoch],
                    'unsup_acc': files['unsup_acc.tsv'][epoch],
                    'ranker': _load_ranker(ranker, files['weights'][f'{epoch}.p']),
                    'ranker_path': files['weights'][f'{epoch}.p'],
                    'optimizer': _load_optimizer(optimizer, files['optimizer'][f'{epoch}.p']),
                    'optimizer_path': files['optimizer'][f'{epoch}.p'],
                    'cached': True,
                })
                if not only_cached:
                    self.fast_forward(b_count)  # skip this epoch
                yield context
                continue

            if only_cached:
                break  # no more cached

            # forward to previous versions (if needed)
            ranker = context['ranker']()
            optimizer = context['optimizer']()
            ranker.train()

            context.update({
                'loss': 0.0,
                'losses': {},
                'acc': 0.0,
                'unsup_acc': 0.0,
            })

            # load EWC state: for every previously seen task, restore the
            # task-optimal parameters and their Fisher information estimates
            # pickled at the end of that task's training
            _optpar_params = {}
            _fisher_params = {}
            tasks = ewc_params.tasks if ewc_params is not None else []
            for task in tasks:
                self.logger.debug(f'loading EWC task {task}')
                _optpar_params[task] = {}
                _fisher_params[task] = {}

                for name, param in ranker.named_parameters():
                    if param.requires_grad:
                        optpar_path = ewc_params.getOptpar(task, name)
                        fisher_path = ewc_params.getFisher(task, name)

                        with open(optpar_path, 'rb') as f:
                            _optpar_params[task][name] = pickle.load(f).to(self.device)
                        with open(fisher_path, 'rb') as f:
                            _fisher_params[task][name] = pickle.load(f).to(self.device)

            with tqdm(leave=False, total=b_count, ncols=100, desc=f'train {epoch}') as pbar:
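                # each epoch: batches_per_epoch logical batches, each split
                # into num_microbatches gradient-accumulation steps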
                for b in range(context['batches_per_epoch']):
                    for _ in range(context['num_microbatches']):
                        self.epoch = epoch
                        train_batch_result = self.train_batch()
                        losses = train_batch_result['losses']
                        loss_weights = train_batch_result['loss_weights']
                        acc = train_batch_result.get('acc')
                        unsup_acc = train_batch_result.get('unsup_acc')

                        # EWC penalty: outside Fisher-estimation mode, add
                        # ewc_lambda * F_i * (optpar_i - param_i)^2, summed over
                        # the parameters of every previously learned task, to
                        # anchor weights that mattered to earlier tasks
                        if not self._ewc:
                            for task in tasks:
                                for name, param in ranker.named_parameters():
                                    if param.requires_grad:
                                        fisher = _fisher_params[task][name]
                                        optpar = _optpar_params[task][name]
                                        losses['data'] += (
                                            fisher * (optpar - param).pow(2)
                                        ).sum() * ewc_params.ewc_lambda

                        losses['data'] = losses['data'].mean()
                        loss = sum(losses[k] * loss_weights.get(k, 1.) for k in losses) / context['num_microbatches']

                        context['loss'] += loss.item()
                        for lname, lvalue in losses.items():
                            context['losses'].setdefault(lname, 0.)
                            context['losses'][lname] += lvalue.item() / context['num_microbatches']

                        if acc is not None:
                            context['acc'] += acc.item() / context['num_microbatches']
                        if unsup_acc is not None:
                            context['unsup_acc'] += unsup_acc.item() / context['num_microbatches']

                        if loss.grad_fn is not None:
                            if hasattr(optimizer, 'backward'):
                                optimizer.backward(loss)
                            else:
                                loss.backward()
                        else:
                            self.logger.warn(
                                'loss has no grad_fn; skipping batch')
                        pbar.update(self.batch_size)

                    postfix = {
                        'loss': context['loss'] / (b + 1),
                    }
                    for lname, lvalue in context['losses'].items():
                        if lname in loss_weights and loss_weights[lname] != 1.:
                            postfix[f'{lname}({loss_weights[lname]})'] = lvalue / (b + 1)
                        else:
                            postfix[lname] = lvalue / (b + 1)

                    if postfix['loss'] == postfix['data']:
                        del postfix['data']

                    pbar.set_postfix(postfix)
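                    # gradients from all microbatches have accumulated by now;
                    # take one optimizer step per logical batch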
                    optimizer.step()
                    optimizer.zero_grad()

            context.update({
                'ranker': lambda: ranker,
                'ranker_path': files['weights'][f'{epoch}.p'],
                'optimizer': lambda: optimizer,
                'optimizer_path': files['optimizer'][f'{epoch}.p'],
                'loss': context['loss'] / context['batches_per_epoch'],
                'losses': {k: v / context['batches_per_epoch'] for k, v in context['losses'].items()},
                'acc': context['acc'] / context['batches_per_epoch'],
                'unsup_acc': context['unsup_acc'] / context['batches_per_epoch'],
                'cached': False,
            })

            if self._ewc:
                # Fisher-estimation pass (EWC recipe adapted from
                # https://github.com/ContinualAI/colab/blob/master/notebooks/intro_to_continual_learning.ipynb):
                # instead of a context, yield the trainable parameters so the
                # caller can compute and persist per-parameter EWC statistics
                params_for_ewc = {
                    n: p for n, p in ranker.named_parameters() if p.requires_grad
                }
                yield params_for_ewc

                optimizer.step()
                optimizer.zero_grad()

            # save stuff
            ranker.save(files['weights'][f'{epoch}.p'])
            torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])
            files['loss.txt'][epoch] = context['loss']
            for lname, lvalue in context['losses'].items():
                files[f'loss_{lname}.txt'][epoch] = lvalue
            files['acc.tsv'][epoch] = context['acc']
            files['unsup_acc.tsv'][epoch] = context['unsup_acc']
            files['complete.tsv'][epoch] = 1  # mark as completed

            yield context
Example #2
    def _load_ranker_weights_epoch(self, ranker, vocab, trainer, dataset):
        epoch = self.config['epoch']
        base_path = util.path_model_trainer(ranker, vocab, trainer, dataset)
        weight_path = os.path.join(base_path, 'weights', f'{epoch}.p')
        self._load_ranker_weights_file_path(ranker, weight_path)
Example #3
    def run(self):
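        """Train with validation-based model selection, then optionally test.

        After training completes, the best epoch is saved and one extra EWC
        pass estimates and persists per-parameter statistics for this task.
        """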
        validator = self.valid_pred.pred_ctxt()

        top_epoch, top_value, top_train_ctxt, top_valid_ctxt = None, None, None, None
        prev_train_ctxt = None

        file_output = {
            'ranker': self.trainer.ranker.path_segment(),
            'vocab': self.trainer.vocab.path_segment(),
            'trainer': self.trainer.path_segment(),
            'dataset': self.trainer.dataset.path_segment(),
            'valid_ds': self.valid_pred.dataset.path_segment(),
            'validation_metric': self.config['val_metric'],
            'logfile': util.path_log()
        }

        # initialize dataset(s)
        if not self.config['skip_ds_init']:
            self.trainer.dataset.init(force=False)
            self.valid_pred.dataset.init(force=False)
            if self.config['test']:
                self.test_pred.dataset.init(force=False)

        base_path_g = None
        ewc_params = {}

        # initialize EWC
        base_path = util.path_model_trainer(self.trainer.ranker, self.trainer.vocab,
                                            self.trainer, self.trainer.dataset)
        ewc_path = "/".join(base_path.split("/")[:-1]) + "/ewc-" + self.config['ewc'] + ".pickle"
        task_name = self.trainer.dataset.path_segment().split("_")[0]

        # load the persisted EWC store for previous tasks, or start a new one
        try:
            with open(ewc_path, 'rb') as f:
                my_ewc = pickle.load(f)
        except (OSError, IOError):
            my_ewc = EWCValues(path=ewc_path, ewc_lambda=float(self.config['ewc']))

        _train_it = self.trainer.iter_train(
            only_cached=self.config['only_cached'],
            _top_epoch=self.config.get('finetune'),
            ewc_params=my_ewc)
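        # the trainer yields a context before training (epoch -1) and one per
        # completed epoch; validate each and track the best-scoring epoch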
        for train_ctxt in _train_it:

            if self.config.get('onlytest'):
                base_path_g = train_ctxt['base_path']
                self.logger.debug('[catfog] skipping training')
                top_train_ctxt = train_ctxt
                break

            if (prev_train_ctxt is not None and top_epoch is not None
                    and prev_train_ctxt is not top_train_ctxt):
                self._purge_weights(prev_train_ctxt)

            if train_ctxt['epoch'] >= 0 and not self.config['only_cached']:
                message = self._build_train_msg(train_ctxt)

                if train_ctxt['cached']:
                    self.logger.debug(f'[train] [cached] {message}')
                else:
                    self.logger.debug(f'[train] {message}')

            if train_ctxt['epoch'] == -1 and not self.config['initial_eval']:
                continue

            valid_ctxt = dict(validator(train_ctxt))

            message = self._build_valid_msg(valid_ctxt)
            if valid_ctxt['epoch'] >= self.config['warmup']:
                if self.config['val_metric'] == '':
                    top_epoch = valid_ctxt['epoch']
                    top_train_ctxt = train_ctxt
                    top_valid_ctxt = valid_ctxt
                elif top_value is None or valid_ctxt['metrics'][self.config['val_metric']] > top_value:
                    message += ' <---'
                    top_epoch = valid_ctxt['epoch']
                    top_value = valid_ctxt['metrics'][self.config['val_metric']]
                    if top_train_ctxt is not None:
                        self._purge_weights(top_train_ctxt)
                    top_train_ctxt = train_ctxt
                    top_valid_ctxt = valid_ctxt
            else:
                if prev_train_ctxt is not None:
                    self._purge_weights(prev_train_ctxt)

            if not self.config['only_cached']:
                if valid_ctxt['cached']:
                    self.logger.debug(f'[valid] [cached] {message}')
                else:
                    self.logger.info(f'[valid] {message}')

            if top_epoch is not None:
                epochs_since_imp = valid_ctxt['epoch'] - top_epoch
                if self.config['early_stop'] > 0 and epochs_since_imp >= self.config['early_stop']:
                    self.logger.warn(
                        'stopping after epoch {epoch} ({early_stop} epochs with no '
                        'improvement to {val_metric})'.format(**valid_ctxt, **self.config))
                    break

            if train_ctxt['epoch'] >= self.config['max_epoch']:
                self.logger.warn(
                    'stopping after epoch {max_epoch} (max_epoch)'.format(**self.config))
                break

            prev_train_ctxt = train_ctxt

        if not self.config.get('onlytest'):
            self.logger.info('top validation epoch={} {}={}'.format(
                top_epoch, self.config['val_metric'], top_value))

            self.logger.info(f'[catfog: top_train_ctxt] {top_train_ctxt}')
            file_output.update({
                'valid_epoch': top_epoch,
                'valid_run': top_valid_ctxt['run_path'],
                'valid_metrics': top_valid_ctxt['metrics'],
            })

        # save the top train epoch for faster testing without retraining
        if not self.config.get('onlytest'):
            # move the best weights/optimizer state to -2.p
            self.trainer.save_best(top_epoch, top_train_ctxt['base_path'])

            # EWC: switch the trainer into Fisher-estimation mode; the next
            # generator step runs one extra epoch and yields the trainable
            # parameters instead of a context
            self.trainer.setewc()
            ewc_params = next(_train_it)

            # EWC recipe adapted from
            # https://github.com/ContinualAI/colab/blob/master/notebooks/intro_to_continual_learning.ipynb
            # for each trainable parameter, persist its task-optimal value and
            # a Fisher estimate, then register the pair under this task
            for name, param in ewc_params.items():
                optpar_path = my_ewc.getOptpar(task_name, name)
                with open(optpar_path, 'wb') as f:
                    pickle.dump(param, f)

                fisher_path = my_ewc.getFisher(task_name, name)
                with open(fisher_path, 'wb') as f:
                    pickle.dump(param.pow(2), f)

                my_ewc.addNew(task_name, name)

            # persist the updated EWC store for use by future tasks
            with open(ewc_path, 'wb') as f:
                pickle.dump(my_ewc, f)

        # for onlytest, also set finetune=true so the best epoch is loaded on
        # the first iteration
        if self.config.get('onlytest'):
            self.logger.debug('[catfog] loading top context')
            self.logger.debug(f'[catfog] Top epoch context: {dict(top_train_ctxt)}')

        if self.config['test']:
            self.logger.info('Loading top-epoch ranker weights')
            top_train_ctxt['ranker'] = onir.trainers.base._load_ranker(
                top_train_ctxt['ranker'](), top_train_ctxt['ranker_path'])

            self.logger.info('Starting test predictor run')
            with self.logger.duration('testing'):
                test_ctxt = self.test_pred.run(top_train_ctxt)

            file_output.update({
                'test_ds': self.test_pred.dataset.path_segment(),
                'test_run': test_ctxt['run_path'],
                'test_metrics': test_ctxt['metrics'],
            })

        with open(util.path_modelspace() + '/val_test.jsonl', 'at') as f:
            json.dump(file_output, f)
            f.write('\n')

        if not self.config.get('onlytest'):
            self.logger.info('valid run at {}'.format(valid_ctxt['run_path']))
        if self.config['test']:
            self.logger.info('test run at {}'.format(test_ctxt['run_path']))
        if not self.config.get('onlytest'):
            self.logger.info('valid ' + self._build_valid_msg(top_valid_ctxt))
        if self.config['test']:
            self.logger.info('test  ' + self._build_valid_msg(test_ctxt))
            self._write_metrics_file(test_ctxt)
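
Note: EWCValues is defined outside these snippets. The sketch below is a minimal, hypothetical stand-in inferred from the calls above (tasks, ewc_lambda, getOptpar, getFisher, addNew); the per-parameter pickle layout is an assumption, not the project's actual implementation.

import os

class EWCValues:
    """Hypothetical sketch: tracks per-task EWC files and the penalty weight."""

    def __init__(self, path, ewc_lambda):
        self.path = path              # pickle file that persists this object
        self.ewc_lambda = ewc_lambda  # strength of the EWC penalty
        self.tasks = []               # task names registered so far

    def _file(self, kind, task, name):
        # assumption: one pickle per (kind, task, parameter) beside self.path;
        # parameter names contain dots, so flatten them for the filename
        base = os.path.dirname(self.path)
        return os.path.join(base, f'ewc-{kind}-{task}-{name.replace(".", "_")}.p')

    def getOptpar(self, task, name):
        return self._file('optpar', task, name)

    def getFisher(self, task, name):
        return self._file('fisher', task, name)

    def addNew(self, task, name):
        # register the task; iter_train's penalty loop iterates over tasks
        if task not in self.tasks:
            self.tasks.append(task)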
Example #4
    def iter_train(self, only_cached=False, _top_epoch=False):
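        """Baseline iter_train without the EWC machinery: yields one context
        before training (epoch -1), then one per completed epoch."""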
        epoch = -1
        base_path = util.path_model_trainer(self.ranker, self.vocab, self, self.dataset)
        context = {
            'epoch': epoch,
            'batch_size': self.config['batch_size'],
            'batches_per_epoch': self.config['batches_per_epoch'],
            'num_microbatches': 1 if self.config['grad_acc_batch'] == 0 else self.config['batch_size'] // self.config['grad_acc_batch'],
            'device': self.device,
            'base_path': base_path,
        }

        files = trainers.misc.PathManager(base_path)
        self.logger.info(f'train path: {base_path}, batches_per_epoch : {context["batches_per_epoch"]}')
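        # total training instances per epoch, used to size the progress bar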
        b_count = context['batches_per_epoch'] * context['num_microbatches'] * self.batch_size

        ranker = self.ranker.to(self.device)
        optimizer = self.create_optimizer()

        if f'{epoch}.p' not in files['weights']:
            ranker.save(files['weights'][f'{epoch}.p'])
        if f'{epoch}.p' not in files['optimizer']:
            torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])
        
        context.update({
            'ranker': lambda: ranker,
            'ranker_path': files['weights'][f'{epoch}.p'],
            'optimizer': lambda: optimizer,
            'optimizer_path': files['optimizer'][f'{epoch}.p'],
        })

        #load ranker/optimizer
        if _top_epoch:
            # trainer.pipeline=msmarco_train_bm25_k1-0.82_b-0.68.100_mspairs
            __path = "/".join(base_path.split("/")[:-1]) + "/" + self.config['pipeline']
            self.logger.info(f'loading prev model : {__path}')
            w_path = os.path.join(__path, 'weights', '-2.p')
            oppt_path = os.path.join(__path, 'optimizer', '-2.p')
            ranker.load(w_path)
            optimizer.load_state_dict(torch.load(oppt_path))

            if w_path != files['weights']['-2.p']:
                copyfile(w_path, files['weights']['-2.p'])
                copyfile(oppt_path, files['optimizer']['-2.p'])

            context.update({
                'ranker': lambda: ranker,
                'ranker_path': files['weights']['-2.p'],
                'optimizer': lambda: optimizer,
                'optimizer_path': files['optimizer']['-2.p'],
            })

        yield context # before training

        while True:
            context = dict(context)

            epoch = context['epoch'] = context['epoch'] + 1
            if epoch in files['complete.tsv']:
                context.update({
                    'loss': files['loss.txt'][epoch],
                    'data_loss': files['data_loss.txt'][epoch],
                    'losses': {},
                    'acc': files['acc.tsv'][epoch],
                    'unsup_acc': files['unsup_acc.tsv'][epoch],
                    'ranker': _load_ranker(ranker, files['weights'][f'{epoch}.p']),
                    'ranker_path': files['weights'][f'{epoch}.p'],
                    'optimizer': _load_optimizer(optimizer, files['optimizer'][f'{epoch}.p']),
                    'optimizer_path': files['optimizer'][f'{epoch}.p'],
                    'cached': True,
                })
                if not only_cached:
                    self.fast_forward(b_count) # skip this epoch
                yield context
                continue

            if only_cached:
                break # no more cached

            # forward to previous versions (if needed)
            ranker = context['ranker']()
            optimizer = context['optimizer']()
            ranker.train()

            context.update({
                'loss': 0.0,
                'losses': {},
                'acc': 0.0,
                'unsup_acc': 0.0,
            })

            with tqdm(leave=False, total=b_count, ncols=100, desc=f'train {epoch}') as pbar:
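                # batches_per_epoch logical batches per epoch, each split into
                # num_microbatches gradient-accumulation steps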
                for b in range(context['batches_per_epoch']):
                    for _ in range(context['num_microbatches']):
                        self.epoch = epoch
                        train_batch_result = self.train_batch()
                        losses = train_batch_result['losses']
                        loss_weights = train_batch_result['loss_weights']
                        acc = train_batch_result.get('acc')
                        unsup_acc = train_batch_result.get('unsup_acc')

                        loss = sum(losses[k] * loss_weights.get(k, 1.) for k in losses) / context['num_microbatches']

                        context['loss'] += loss.item()
                        for lname, lvalue in losses.items():
                            context['losses'].setdefault(lname, 0.)
                            context['losses'][lname] += lvalue.item() / context['num_microbatches']

                        if acc is not None:
                            context['acc'] += acc.item() / context['num_microbatches']
                        if unsup_acc is not None:
                            context['unsup_acc'] += unsup_acc.item() / context['num_microbatches']

                        if loss.grad_fn is not None:
                            if hasattr(optimizer, 'backward'):
                                optimizer.backward(loss)
                            else:
                                loss.backward()
                        else:
                            self.logger.warn('loss has no grad_fn; skipping batch')
                        pbar.update(self.batch_size)

                    postfix = {
                        'loss': context['loss'] / (b + 1),
                    }
                    for lname, lvalue in context['losses'].items():
                        if lname in loss_weights and loss_weights[lname] != 1.:
                            postfix[f'{lname}({loss_weights[lname]})'] = lvalue / (b + 1)
                        else:
                            postfix[lname] = lvalue / (b + 1)

                    if postfix['loss'] == postfix['data']:
                        del postfix['data']

                    pbar.set_postfix(postfix)
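                    # all microbatch gradients have accumulated; one optimizer
                    # step per logical batch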
                    optimizer.step()
                    optimizer.zero_grad()

            context.update({
                'ranker': lambda: ranker,
                'ranker_path': files['weights'][f'{epoch}.p'],
                'optimizer': lambda: optimizer,
                'optimizer_path': files['optimizer'][f'{epoch}.p'],
                'loss': context['loss'] / context['batches_per_epoch'],
                'losses': {k: v / context['batches_per_epoch'] for k, v in context['losses'].items()},
                'acc': context['acc'] / context['batches_per_epoch'],
                'unsup_acc': context['unsup_acc'] / context['batches_per_epoch'],
                'cached': False,
            })

            # save stuff
            ranker.save(files['weights'][f'{epoch}.p'])
            torch.save(optimizer.state_dict(), files['optimizer'][f'{epoch}.p'])
            files['loss.txt'][epoch] = context['loss']
            for lname, lvalue in context['losses'].items():
                files[f'loss_{lname}.txt'][epoch] = lvalue
            files['acc.tsv'][epoch] = context['acc']
            files['unsup_acc.tsv'][epoch] = context['unsup_acc']
            files['complete.tsv'][epoch] = 1 # mark as completed

            yield context