Example #1
0
def normal_training(config):
    """Run a plain (non-Tune) training loop driven by *config*.

    Sets up the experiment logger, model, dataloaders and loss function,
    then trains for the configured number of epochs, periodically logging
    training metrics and running evaluation.

    Args:
        config: dict with at least 'device', 'experiment_name' and a
            'training' section containing 'training_epochs',
            'log_every_n_batches' and 'eval_every_n_batches'.
    """
    device = torch.device(config['device'])
    print('Using device', device)
    exp, model, train_dataloader, eval_dataloader, loss_func = setup_training(
        config)
    exp.set_name(config['experiment_name'])
    model.train()
    model = model.to(device)
    optimizers = get_optimizers(model, config)
    evaluator = Evaluation(eval_dataloader, config)

    num_examples = 0
    for epoch in range(config['training']['training_epochs']):
        for idx, batch in enumerate(train_dataloader):
            batch = (batch[0].to(device), batch[1].to(device))
            num_examples += len(batch[0])
            loss, train_accuracy = training_step(batch, model, optimizers,
                                                 loss_func)
            if idx % config['training']['log_every_n_batches'] == 0:
                # Detach/transfer the loss once and reuse it for both the
                # console print and the experiment logger (was computed twice).
                loss_value = loss.detach().cpu().numpy()
                print(epoch, num_examples, loss_value)
                exp.log_metric('train_loss',
                               loss_value,
                               step=num_examples,
                               epoch=epoch)
                # Fix: train_accuracy was computed by training_step but never
                # logged, unlike the Tune-based trainer which reports it.
                exp.log_metric('train_accuracy',
                               train_accuracy,
                               step=num_examples,
                               epoch=epoch)

            # NOTE: idx == 0 satisfies the modulo, so evaluation also runs on
            # the first batch of every epoch.
            if idx % config['training']['eval_every_n_batches'] == 0:
                results = evaluator.eval_model(model, loss_func)
                for metric in results:
                    print(metric, results[metric])
                    exp.log_metric(metric,
                                   results[metric],
                                   step=num_examples,
                                   epoch=epoch)
Example #2
0
    # NOTE(review): this is the interior of a test/evaluation routine whose
    # `def` line lies before this chunk — documented in place only.
    print('Model loaded.')

    # Count trainable parameters and log them alongside the flattened config
    # so the experiment records the full run configuration in one place.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_dict = helpers.flatten_dict(config)
    log_dict.update({'trainable_params': n_params})
    exp.log_parameters(log_dict)

    # Build the test dataloader: no shuffling and no dropped batches, so
    # every test example is evaluated exactly once, in order.
    test_dataset = data.CSVDatasetsMerger(helpers.get_datasets_paths(config, 'test'))
    test_dataloader = DataLoader(test_dataset,
                       batch_size=config['evaluation']['eval_batch_size'],
                       shuffle=False,
                       drop_last=False,
                       num_workers=config['evaluation']['n_eval_workers'],
                       collate_fn=text_proc)

    evaluator = Evaluation(test_dataloader, config)

    print('Testing ...')
    # With finished_training=True, eval_model also returns uploadable assets
    # and image file names (presumably plots/artifacts — see Evaluation).
    results, assets, image_fns = evaluator.eval_model(model, finished_training=True)
    print('Finished testing. Uploading ...')

    # step/epoch fixed at 0: this is a one-off final test pass, not part of
    # a training timeline.
    exp.log_metrics(results, step=0, epoch=0)
    [exp.log_asset_data(asset, step=0) for asset in assets]
    [exp.log_image(fn, step=0) for fn in image_fns]

    print('Finished uploading.')




Example #3
0
class TuneTrainable(Trainable):
    """Ray Tune ``Trainable`` wrapping the project's training loop.

    Each call to ``_train`` runs batches until the next evaluation point
    and reports the best-seen value of the configured discriminating
    metric, plus a plateau flag (``no_change_in_accu``) that Tune's
    stopper can use for early termination.
    """

    def _setup(self, config):
        """Build model, data, logging and optimizer state from *config*.

        Called once by Ray Tune before training starts.
        """
        inject_tuned_hyperparameters(config, config)
        # Ray workers may start in a different cwd; anchor relative paths
        # to this file's directory.
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        print('Trainable got the following config after injection', config)
        self.config = config
        self.device = self.config['device']
        self.exp, self.model, self.train_dataloader, self.eval_dataloader = setup_training(
            self.config)
        self.exp.set_name(config['experiment_name'] + self._experiment_id)
        self.exp_name = config['experiment_name'] + self._experiment_id
        # NOTE(review): this "ended" notification is sent during setup, not
        # at shutdown — looks misplaced (stop() would be the natural home);
        # confirm against the experiment-tracker API before moving it.
        self.exp.send_notification(title='Experiment ' +
                                   str(self._experiment_id) + ' ended')
        # Manual iterator so _train can pull batches across Tune iterations
        # and restart it on StopIteration (see get_batch).
        self.train_data_iter = iter(self.train_dataloader)
        self.model = self.model.to(self.device)
        self.model.train()
        n_params = sum(p.numel() for p in self.model.parameters()
                       if p.requires_grad)
        log_dict = flatten_dict(config)
        log_dict.update({'trainable_params': n_params})
        self.exp.log_parameters(log_dict)
        self.optimizers = get_optimizers(self.model, self.config)
        self.evaluator = Evaluation(self.eval_dataloader, self.config)
        # Running counters used for logging cadence and plateau detection.
        self.num_examples = 0
        self.batch_idx = 0
        self.epoch = 1
        # Exponentially weighted moving average of the discriminating
        # metric; used to detect a stagnating run.
        self.ewma = EWMA(beta=0.75)
        self.last_accu = -1.0
        self.max_accu = -1.0
        # Gradient accumulation: optimizers step only every N batches.
        self.back_prop_every_n_batches = config['training'][
            'back_prop_every_n_batches']
        self.checkpoint_best = config['training']['checkpoint_best']

    def get_batch(self):
        """Return the next training batch, cycling the dataloader.

        When the underlying iterator is exhausted, restart it, reset the
        per-epoch batch counter and bump the epoch count.
        """
        try:
            batch = next(self.train_data_iter)
            return batch

        except StopIteration:
            self.train_data_iter = iter(self.train_dataloader)
            batch = next(self.train_data_iter)
            self.batch_idx = 0
            self.epoch += 1
            return batch

    def _train(self):
        """Run one Tune iteration: train until the next evaluation point.

        Returns a dict containing the best value of the discriminating
        metric so far, the number of seen examples, and a plateau flag.
        """
        # Accumulators for the average loss/accuracy of the current
        # logging window.
        total_log_step_loss = 0
        total_log_step_train_accu = 0
        total_log_step_n = 0

        [opt.zero_grad() for opt in self.optimizers]
        while True:
            batch = self.get_batch()
            self.batch_idx += 1
            self.num_examples += len(batch[0])
            batch = (batch[0].to(self.device), batch[1].to(self.device))
            # step=True only every back_prop_every_n_batches batches, i.e.
            # gradients are accumulated between optimizer steps.
            loss, train_accu = training_step(
                batch,
                self.model,
                self.optimizers,
                step=(self.batch_idx % self.back_prop_every_n_batches == 0))
            total_log_step_loss += loss.cpu().detach().numpy()
            total_log_step_train_accu += train_accu
            total_log_step_n += 1

            if self.batch_idx % self.config['training'][
                    'log_every_n_batches'] == 0:
                avg_loss = total_log_step_loss / total_log_step_n
                avg_accu = total_log_step_train_accu / total_log_step_n
                total_log_step_n = 0
                print(f'{Fore.YELLOW}Total number of seen examples:',
                      self.num_examples, 'Average loss of current log step:',
                      avg_loss, 'Average train accuracy of current log step:',
                      avg_accu, f"{Style.RESET_ALL}")
                self.exp.log_metric('train_loss',
                                    avg_loss,
                                    step=self.num_examples,
                                    epoch=self.epoch)
                self.exp.log_metric('train_accuracy',
                                    avg_accu,
                                    step=self.num_examples,
                                    epoch=self.epoch)
                total_log_step_loss = 0
                total_log_step_train_accu = 0

            # NOTE(review): eval fires on (batch_idx + 1) % N == 0 while the
            # log branch above uses batch_idx % N == 0 — the two cadences
            # are offset by one batch; confirm this offset is intentional.
            if (self.batch_idx +
                    1) % self.config['training']['eval_every_n_batches'] == 0:
                results, assets, image_fns = self.evaluator.eval_model(
                    self.model)
                print(self.config['tune']['discriminating_metric'],
                      results[self.config['tune']['discriminating_metric']])
                self.exp.log_metrics(results,
                                     step=self.num_examples,
                                     epoch=self.epoch)
                [
                    self.exp.log_asset_data(asset, step=self.num_examples)
                    for asset in assets
                ]
                [
                    self.exp.log_image(fn, step=self.num_examples)
                    for fn in image_fns
                ]

                # Plateau detection: compare the new metric against both the
                # EWMA (smoothed) and the last raw value.
                accu_diff_avg = abs(
                    results[self.config['tune']['discriminating_metric']] -
                    self.ewma.get())
                accu_diff_cons = abs(
                    results[self.config['tune']['discriminating_metric']] -
                    self.last_accu)

                # Flag "no change" only after a warm-up of 70k examples and
                # when both deltas are tiny (thresholds are heuristic).
                no_change_in_accu = 1 if accu_diff_avg < 0.0005 and accu_diff_cons < 0.002 and self.num_examples > 70000 else 0
                self.ewma.update(
                    results[self.config['tune']['discriminating_metric']])
                self.last_accu = results[self.config['tune']
                                         ['discriminating_metric']]

                # Track the best metric and optionally checkpoint on a new
                # best model.
                if self.max_accu < results[self.config['tune']
                                           ['discriminating_metric']]:
                    self.max_accu = results[self.config['tune']
                                            ['discriminating_metric']]
                    if self.checkpoint_best:
                        self.save_checkpoint('checkpoints',
                                             self.exp_name + '.pt')
                        print(
                            f'{Fore.GREEN}New best model saved.{Style.RESET_ALL}'
                        )

                self.exp.log_metric('max_accuracy',
                                    self.max_accu,
                                    step=self.num_examples,
                                    epoch=self.epoch)

                # Tune receives the best-so-far metric (not the current one),
                # plus the plateau flag for custom stoppers.
                training_results = {
                    self.config['tune']['discriminating_metric']:
                    self.max_accu,
                    'num_examples': self.num_examples,
                    'no_change_in_accu': no_change_in_accu
                }

                return training_results

    def _save(self, checkpoint_dir):
        """Ray Tune checkpoint hook; returns the checkpoint file path."""
        return self.save_checkpoint(checkpoint_dir, 'checkpoint_file.pt')

    def save_checkpoint(self, checkpoint_dir, fname='checkpoint_file.pt'):
        """Save model and all optimizer state dicts to *checkpoint_dir*.

        Returns the full path of the written checkpoint file.
        """
        print(f'{Fore.CYAN}Saving model ...{Style.RESET_ALL}')
        save_dict = {'model_state_dict': self.model.state_dict()}
        # One entry per optimizer: 'op_0_state_dict', 'op_1_state_dict', ...
        for i, optimizer in enumerate(self.optimizers):
            save_dict['op_' + str(i) + '_state_dict'] = optimizer.state_dict()
        torch.save(save_dict, os.path.join(checkpoint_dir, fname))
        return os.path.join(checkpoint_dir, fname)

    def _restore(self, checkpoint_path):
        """Ray Tune restore hook: reload model and optimizer state.

        Keys must match the layout written by save_checkpoint.
        """
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        for i, optimizer in enumerate(self.optimizers):
            optimizer.load_state_dict(checkpoint['op_' + str(i) +
                                                 '_state_dict'])

    def stop(self):
        """Run a final evaluation and upload results before shutdown."""
        results, assets, image_fns = self.evaluator.eval_model(
            self.model, finished_training=True)
        self.exp.log_metrics(results, step=self.num_examples, epoch=self.epoch)
        [
            self.exp.log_asset_data(asset, step=self.num_examples)
            for asset in assets
        ]
        [self.exp.log_image(fn, step=self.num_examples) for fn in image_fns]

        return super().stop()