def train(args, model, train_dataset, eval_dataset):
	"""Train `model` with BCE loss, evaluating every `args.logging_step` steps.

	Progress lines are printed and appended to a `log` file under
	`args.out/args.save`; the model/optimizer state with the best eval AUC
	seen so far is checkpointed to the same directory.

	Args:
		args: namespace with batch_size, lr, epochs, logging_step, device,
			out, save.
		model: the torch module to train (moved to `args.device` by caller).
		train_dataset / eval_dataset: map-style datasets yielding
			(xarray, position, token_type_list, mask, ylabel) tuples.
	"""
	train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
	eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, num_workers=8)

	# NOTE(review): BCELoss expects sigmoid probabilities in [0, 1]; confirm
	# the model's forward applies a sigmoid (otherwise BCEWithLogitsLoss is
	# the numerically safer choice).
	loss_fct = BCELoss()
	optimizer = AdamW(model.parameters(), lr=args.lr)

	print("***** Running training *****")
	print("  Num examples = %d" % (len(train_dataset)))
	print("  Num Val examples = %d" % (len(eval_dataset)))
	print("  Num Epochs = %d" % (args.epochs))
	print("  Batch Size = %d" % (args.batch_size))

	output_dir = join(args.out, args.save)
	os.makedirs(output_dir, exist_ok=True)  # race-free replacement for exists()+makedirs

	global_step = 0
	best_val_auc = 0.0
	running_loss = 0.0
	model.zero_grad()
	# `with` guarantees the log file is closed (and its buffer flushed) even
	# if training raises; the original opened the handle and never closed it.
	with open(join(output_dir, 'log'), 'w') as log_file:
		for epoch in range(args.epochs):
			for step, batch in enumerate(train_dataloader):
				model.train()
				xarray, position, token_type_list, mask, ylabel = batch
				xarray = xarray.to(args.device)
				position = position.to(args.device)
				token_type_list = token_type_list.to(args.device)
				mask = mask.to(args.device)
				ylabel = ylabel.to(args.device)

				output = model(xarray, position, token_type_list, mask)
				# BCELoss requires matching float tensors; flatten both to 1-D.
				loss = loss_fct(output.view(-1).to(torch.float32), ylabel.view(-1).to(torch.float32))

				loss.backward()
				optimizer.step()
				model.zero_grad()

				running_loss += loss.item()
				global_step += 1

				# Evaluate every `logging_step` steps. global_step is >= 1
				# here, so the original `and global_step != 0` was redundant.
				if global_step % args.logging_step == 0:
					eval_result = eval(args, model, eval_dataloader)
					msg = 'Epoch: %d, Global Step: %d, Loss: %.3f, Eval Loss: %.3f, Eval F1score: %.3f, Eval AUC: %.3f' % (
						epoch + 1, global_step, (running_loss / args.logging_step),
						eval_result['loss'], eval_result['f1'], eval_result['auc'])
					print(msg)
					log_file.write(msg + ' \n')
					running_loss = 0.0

					# If eval AUC improves, overwrite the saved checkpoint.
					if eval_result['auc'] > best_val_auc:
						best_val_auc = eval_result['auc']
						torch.save(model.state_dict(), os.path.join(output_dir, "model_state_dict.pt"))
						torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
Example #2
0
def run_training(args, ls):
    """Train the AMR parser end-to-end.

    Loads vocabs and (optionally) a frozen BERT encoder, builds the model
    and an AdamW optimizer with selective weight decay, optionally resumes
    from a checkpoint, then trains with gradient accumulation and a
    warmup-based learning-rate schedule, evaluating and checkpointing every
    `args.eval_every` epochs.

    Args:
        args: namespace of hyper-parameters and paths.
        ls: logger-like object exposing a `print` method.
    """
    ls.print('Training started: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    # Misc setup
    os.makedirs(args.model_dir, exist_ok=True)
    # cnn_filters arrives flat [w1, n1, w2, n2, ...]; pair it into tuples.
    assert len(args.cnn_filters)%2 == 0
    args.cnn_filters = list(zip(args.cnn_filters[:-1:2], args.cnn_filters[1::2]))
    # Load the vocabs
    vocabs = get_vocabs(os.path.join(args.model_dir, args.vocab_dir))
    bert_tokenizer = None
    if args.with_bert:
        bert_tokenizer = BertEncoderTokenizer.from_pretrained(args.bert_path, do_lower_case=False)
        vocabs['bert_tokenizer'] = bert_tokenizer
    for name in vocabs:
        if name == 'bert_tokenizer':
            continue
        ls.print('Vocab %-20s  size %5d  coverage %.3f' % (name, vocabs[name].size, vocabs[name].coverage))
    # Setup BERT encoder; its parameters are frozen (feature extractor only).
    bert_encoder = None
    if args.with_bert:
        bert_encoder = BertEncoder.from_pretrained(args.bert_path)
        for p in bert_encoder.parameters():
            p.requires_grad = False
    # Device and random setup (fixed seed for reproducibility)
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device(args.device)
    # Create the model
    ls.print('Setting up the model')
    model = Parser(vocabs,
            args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim,
            args.concept_char_dim, args.concept_dim,
            args.cnn_filters, args.char2word_dim, args.char2concept_dim,
            args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout,
            args.snt_layers, args.graph_layers, args.inference_layers, args.rel_dim,
            device, args.pretrained_file, bert_encoder,)
    model = model.to(device)
    # Optimizer: no weight decay on biases and layer-norm parameters.
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params':weight_decay_params, 'weight_decay':1e-4},
                        {'params':no_weight_decay_params, 'weight_decay':0.}]
    # lr=1. is a placeholder; update_lr() rescales it on every update step.
    optimizer = AdamW(grouped_params, 1., betas=(0.9, 0.999), eps=1e-6)
    # Re-load an existing model if requested
    used_batches = 0
    batches_acm = 0
    if args.resume_ckpt:
        ls.print('Resuming from checkpoint', args.resume_ckpt)
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        if ckpt.get('optimizer', {}):
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            ls.print('No optimizer state saved in checkpoint, using default initial optimizer')
        batches_acm = ckpt['batches_acm']
        start_epoch = ckpt['epoch'] + 1
        del ckpt    # release checkpoint memory before training starts
    else:
        start_epoch = 1     # don't start at 0
    # Load data
    ls.print('Loading training data')
    train_data = DataLoader(vocabs, args.train_data, args.train_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)
    # Train
    ls.print('Training')
    epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0
    # Defined up-front so the epoch summary cannot hit a NameError when an
    # epoch contains fewer than batches_per_update batches (no update ran).
    lr = 0.0
    for epoch in range(start_epoch, args.epochs+1):
        st = time.time()
        for batch in train_data:
            model.train()
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch)
            # Scale so accumulated gradients average over the virtual batch.
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            # Exponential moving averages, for reporting only (the
            # * batches_per_update undoes the scaling applied to `loss`).
            loss_avg = loss_avg * args.batches_per_update * 0.8 + 0.2 * loss_value
            concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value
            arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value
            rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value
            loss.backward()
            used_batches += 1
            # Step the optimizer only once every batches_per_update batches.
            if not (used_batches % args.batches_per_update == -1 % args.batches_per_update):
                continue
            batches_acm += 1
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            lr = update_lr(optimizer, args.lr_scale, args.embed_dim, batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
        # Summary at the end of the epoch
        dur = time.time() - st
        ls.print('Epoch %4d, Batch %5d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f, duration %.1f seconds' %
                    (epoch, batches_acm, lr, concept_loss_avg, arc_loss_avg, rel_loss_avg, dur))
        # Evaluate and save the data every so often
        if (epoch>args.skip_evals or args.resume_ckpt is not None) and epoch % args.eval_every == 0:
            model.eval()
            ls.print('Evaluating and saving the model')
            fname = '%s/epoch%d.pt'%(args.model_dir, epoch)
            optim = optimizer.state_dict() if args.save_optimizer else {}
            torch.save({'args':vars(args), 'model':model.state_dict(), 'batches_acm': batches_acm,
                        'optimizer': optim, 'epoch':epoch}, fname)
            try:
                out_fn = 'epoch%d.pt.dev_generated' % (epoch)
                inference = Inference.build_from_model(model, vocabs)
                f_score, ctr = inference.reparse_annotated_file('.', args.dev_data, args.model_dir, out_fn,
                        print_summary=False)
                ls.print('Smatch F: %.3f.  Wrote %d AMR graphs to %s' % \
                        (f_score, ctr, os.path.join(args.model_dir, out_fn)))
            except Exception:
                # A bare `except:` would also trap KeyboardInterrupt and
                # SystemExit, making the run un-interruptible during eval.
                ls.print('Exception during generation')
                traceback.print_exc()
            model.train()
    # End time-stamp
    ls.print('Training finished: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
Example #3
0
class Trainer(object):
    """Full-batch trainer over a fixed node-classification graph.

    Holds disjoint train/val/test node index sets and repeatedly runs the
    model on a single (input_features, adjacency) pair, logging metrics,
    checkpointing, tracking the best model by validation loss, and
    optionally early-stopping.

    Node indices are offset by `vocab_size` when indexing into the model's
    logits (see `_train_epoch`), i.e. logits rows for labeled nodes start
    after the first `vocab_size` rows.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float,
        device: torch.device,
        train_nodes: torch.LongTensor,
        val_nodes: torch.LongTensor,
        test_nodes: torch.LongTensor,
        vocab_size: int,
        results_dir: str,
        validate_every_n_epochs: int,
        save_after_n_epochs: int,
        checkpoint_every_n_epochs: int,
        use_early_stopping: bool,
        early_stopping_epochs: int,
        autodelete_checkpoints: bool,
    ):
        self.device = device
        self.model = model
        self.model.to(self.device)
        self.optimiser = AdamW(
            params=model.parameters(),
            lr=learning_rate,
        )
        self.loss_fn = nn.CrossEntropyLoss()

        # Guard against data leakage: the splits must be disjoint.
        assert (
            len(set(train_nodes).intersection(set(val_nodes))) == 0
        ), f'There are overlapping nodes: {len(set(train_nodes).intersection(set(val_nodes)))}'
        assert (
            len(set(train_nodes).intersection(set(test_nodes))) == 0
        ), f'There are overlapping nodes: {len(set(train_nodes).intersection(set(test_nodes)))}'
        self.train_nodes = train_nodes
        self.val_nodes = val_nodes
        self.test_nodes = test_nodes
        self.vocab_size = vocab_size
        print(f'Vocabulary offset: {vocab_size}')

        self.results_dir = results_dir
        self.validate_every_n_epochs = validate_every_n_epochs
        self.save_after_n_epochs = save_after_n_epochs
        self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
        self.use_early_stopping = use_early_stopping
        self.early_stopping_epochs = early_stopping_epochs
        # Tracks whether the metrics log file has been written to yet, so
        # the first save can emit a header (see save_metrics call below).
        self.has_saved_metric = False
        self._setup_dirs()

        # Best-model selection is driven by validation loss (lower = better;
        # see _is_best for how the direction is inferred from the name).
        self.metric_of_interest = 'val loss'
        self.best_metric = math.inf
        self.last_epoch_with_improvement = 1
        self.autodelete_checkpoints = autodelete_checkpoints

    def _setup_dirs(self):
        """Create the checkpoint / best-model / best-predictions folders."""
        self.ckpt_dir = os.path.join(self.results_dir, 'ckpt')
        self.best_model_dir = os.path.join(self.results_dir, 'best', 'models')
        self.best_preds_dir = os.path.join(self.results_dir, 'best',
                                           'predictions')
        os.makedirs(self.ckpt_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)
        os.makedirs(self.best_preds_dir, exist_ok=True)

    def __call__(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
        num_epochs: int,
    ):
        """Run the training loop for `num_epochs` epochs.

        Validates every `validate_every_n_epochs` (and on epoch 1),
        checkpoints/saves the best model on validation epochs, and breaks
        out early when early stopping is enabled and no improvement has
        been seen for `early_stopping_epochs` epochs.
        """
        with trange(num_epochs, desc='Training progress: ') as t:
            for epoch_num in range(1, num_epochs + 1):
                train_metrics = self._train_epoch(input_features, adjacency,
                                                  labels)

                if (epoch_num %
                        self.validate_every_n_epochs) == 0 or epoch_num == 1:
                    # Validate and save metrics
                    val_metrics = self._val_epoch(input_features, adjacency,
                                                  labels)
                    save_metrics(
                        file_path=os.path.join(self.results_dir,
                                               'train-log.jsonl'),
                        epoch_num=epoch_num,
                        train_metrics=train_metrics,
                        val_metrics=val_metrics,
                        is_first_metric_save=not self.has_saved_metric,
                    )
                    self.has_saved_metric = True

                    if epoch_num > self.save_after_n_epochs and (
                            epoch_num % self.checkpoint_every_n_epochs) == 0:
                        # Save model
                        self._checkpoint_model(epoch_num)
                        if self._is_best(val_metrics):
                            self._save_best_model(epoch_num)
                            self._save_test_predictions(
                                input_features, adjacency, labels, epoch_num)

                    # NOTE(review): _is_best mutates best_metric, and can be
                    # called twice in one epoch (above and below). For loss
                    # metrics the `<=` comparison makes the second call still
                    # return True on the same value — confirm this is intended.
                    if self.use_early_stopping:
                        if self._is_best(val_metrics):
                            self.last_epoch_with_improvement = epoch_num
                        if epoch_num > self.last_epoch_with_improvement + self.early_stopping_epochs:
                            note = f'Breaking on epoch {epoch_num} after no improvement since epoch \
                                {self.last_epoch_with_improvement}'

                            print(note)
                            save_training_notes(
                                file_path=os.path.join(self.results_dir,
                                                       'training-notes.jsonl'),
                                epoch_num=epoch_num,
                                note=note,
                            )

                            break

                else:
                    # if we haven't validated, create an empty val metric dict
                    val_metrics = {'val loss': None}

                t.set_postfix(train_loss=train_metrics['train loss'],
                              val_loss=val_metrics['val loss'])
                t.update()

        return None

    def _train_epoch(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> Dict[str, Any]:
        """Run one full-batch training step; returns duration and loss.

        NOTE: Although we pass in all input features and labels,
              we only evaluate the loss on training set node indicies.
        """
        start_time = time.time()

        self.model.train()
        self.optimiser.zero_grad()
        logits = self.model(input_features, adjacency)
        # Node indices are offset by vocab_size into the logits rows.
        train_loss = self.loss_fn(logits[self.train_nodes + self.vocab_size],
                                  labels[self.train_nodes])

        train_loss.backward()
        self.optimiser.step()

        duration = time.time() - start_time
        return {
            'train epoch duration': duration,
            'train loss': train_loss.item()
        }

    def _val_epoch(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> Dict[str, Any]:
        """Compute validation loss and accuracy on the val node set.

        NOTE(review): the forward pass is not wrapped in torch.no_grad(),
        so autograd state is built unnecessarily here — confirm and consider
        adding it.
        """
        self.model.eval()
        logits = self.model(input_features, adjacency)
        val_loss = self.loss_fn(logits[self.val_nodes + self.vocab_size],
                                labels[self.val_nodes])

        val_accuracy = accuracy(logits[self.val_nodes + self.vocab_size],
                                labels[self.val_nodes],
                                is_logit_output=True)
        return {
            'val loss': val_loss.item(),
            'F-score': None,
            'Accuracy': val_accuracy
        }

    def _checkpoint_model(self, epoch: int) -> None:
        """ Checkpoint to resume training """
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimiser_state_dict': self.optimiser.state_dict(),
                'loss': self.loss_fn,
            },
            os.path.join(self.ckpt_dir, f'model-{epoch}.pt'),
        )
        # delete existing checkpoints (except for the one we just saved)
        if self.autodelete_checkpoints:
            checkpoints = glob(os.path.join(self.ckpt_dir, f'model-*.pt'))
            old_checkpoints = [
                checkpoint for checkpoint in checkpoints
                if f'model-{epoch}' not in checkpoint
            ]
            for old_checkpoint in old_checkpoints:
                os.remove(old_checkpoint)

    def _save_best_model(self, epoch: int) -> None:
        """ Save best model for inference """
        torch.save(self.model.state_dict(),
                   os.path.join(self.best_model_dir, f'model-{epoch}.pt'))
        remove_previous_best_model(self.best_model_dir, epoch)

    def _save_test_predictions(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
        epoch: int,
    ) -> None:
        """ Save test set predictions for the best model """
        # NOTE(review): also runs without torch.no_grad() — see _val_epoch.
        self.model.eval()
        logits = self.model(input_features, adjacency)
        predictions = get_predictions(logits[self.test_nodes +
                                             self.vocab_size],
                                      labels[self.test_nodes],
                                      is_logit_output=True)
        torch.save(
            predictions,
            os.path.join(self.best_preds_dir, f'predictions-{epoch}.pt'))
        remove_previous_best_predictions(self.best_preds_dir, epoch)

    def _is_best(self, val_metrics: Dict[str, float]) -> bool:
        """Return True if `val_metrics` is the best seen so far.

        Side effect: updates `self.best_metric` on improvement. Direction
        is inferred from the metric name: names containing 'loss' are
        minimised (with `<=`, so ties count as improvement); anything else
        is maximised (strict `>`).
        """
        if 'loss' in self.metric_of_interest:
            if val_metrics[self.metric_of_interest] <= self.best_metric:
                self.best_metric = val_metrics[self.metric_of_interest]
                return True
            else:
                return False
        else:
            # Assume we want to maximise it if it is not a loss
            if val_metrics[self.metric_of_interest] > self.best_metric:
                self.best_metric = val_metrics[self.metric_of_interest]
                return True
            else:
                return False

    def save_test_metrics(
        self,
        input_features: torch.FloatTensor,
        adjacency: torch.sparse.FloatTensor,
        labels: torch.LongTensor,
    ) -> None:
        """Score the saved best-model test predictions and write a JSON log.

        Expects exactly one predictions file in `best_preds_dir` (earlier
        ones are removed by `_save_test_predictions`).
        """
        files_in_dir = os.listdir(self.best_preds_dir)
        assert len(
            files_in_dir
        ) == 1, f'Found more than one prediction file in:\n{files_in_dir}'
        test_predictions = torch.load(
            os.path.join(self.best_preds_dir, files_in_dir[0]))
        test_labels = labels[self.test_nodes]

        num_correct = float(
            torch.sum(
                torch.eq(test_predictions.type_as(test_labels), test_labels)))
        test_accuracy = num_correct / len(test_labels)
        test_macro_f1 = f1_score(labels[self.test_nodes],
                                 test_predictions,
                                 average='macro')

        save_dict_to_json(
            {
                'test-accuracy': test_accuracy,
                'test_macro_f1': test_macro_f1
            },
            os.path.join(self.results_dir, 'test-log.jsonl'),
        )
Example #4
0
class Trainer():
    """Trainer for the Classifier_M3 audio model.

    Wraps an AdamW optimizer with cosine-annealing LR, BCE-with-logits
    loss, optional checkpoint resume, and per-epoch train/test iteration
    with micro-F1 tracking. Losses/F1 histories are kept in lists for
    `export_log`.
    """

    def __init__(self,
                 train_dataloader,
                 test_dataloader,
                 lr,
                 betas,
                 weight_decay,
                 log_freq,
                 with_cuda,
                 model=None):
        # `model` here is a checkpoint *path*, not a module; the network
        # itself is always a freshly constructed Classifier_M3.

        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")

        self.model = Classifier_M3().to(self.device)
        self.optim = AdamW(self.model.parameters(),
                           lr=lr,
                           betas=betas,
                           weight_decay=weight_decay)
        # NOTE(review): scheduler is stepped once per *batch* in iteration(),
        # with T_max=5 — confirm a per-batch cosine cycle is intended.
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()

        # Resume model/optimizer/epoch/loss state from a checkpoint file.
        # NOTE(review): self.epoch is only defined on this path.
        if model != None:
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.criterion = checkpoint['loss']

        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        print("Using %d GPUS for Converter" % torch.cuda.device_count())

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.log_freq = log_freq
        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

        # Per-epoch metric histories, consumed by export_log().
        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []

    def train(self, epoch):
        """Run one training epoch over the train dataloader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation epoch over the test dataloader."""
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        Run one pass over `data_loader`, tracking average loss and micro-F1.

        :param epoch: current epoch number
        :param data_loader: torch.utils.data.DataLoader
        :param train: bool — True for a training pass, False for test

        NOTE(review): `post_fix`/`self.avg_loss` are defined inside the batch
        loop but used after it — an empty data_loader would raise NameError.
        """
        str_code = "train" if train else "test"

        data_iter = tqdm(enumerate(data_loader),
                         desc="EP_%s:%d" % (str_code, epoch),
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")

        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0

        for i, data in data_iter:
            # Batch layout: (spectrogram, one-hot label, class-index label).
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)
            predict_label = self.model(specgram, train)

            # Micro-F1 between true class indices and thresholded predictions.
            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro')

            loss = self.criterion(predict_label, one_hot_label)

            # Backprop + optimizer/scheduler step only in training mode.
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.scheduler.step()

            loss_store += loss.item()
            f1_score_store += predict_f1_score
            # Running averages over batches seen so far this epoch.
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }

        data_iter.write(str(post_fix))
        # Append the epoch averages to the matching history list.
        self.train_loss.append(
            self.avg_loss) if train else self.test_loss.append(self.avg_loss)
        self.train_f1_score.append(
            self.avg_f1_score) if train else self.test_f1_score.append(
                self.avg_f1_score)

    def save(self, epoch, file_path="../models/2k/"):
        """
        Save a resumable checkpoint (model, optimizer, epoch, criterion)
        and return its path. The model is moved to CPU for saving, then
        restored to the training device.
        """
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.cpu().state_dict(),
                'optimizer_state_dict': self.optim.state_dict(),
                'criterion': self.criterion
            }, output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def export_log(self, epoch, file_path="../../logs/2k/"):
        """Write the accumulated loss/F1 histories to a CSV log file.

        NOTE(review): assumes the train/test history lists have equal
        length (pandas requires it) — i.e. train() and test() were called
        the same number of times.
        """
        df = pd.DataFrame({
            "train_loss": self.train_loss,
            "test_loss": self.test_loss,
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path + f"loss_timestrech.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
def train(args, train_dataset, model, tokenizer, writer):
    """Fine-tune `model` on `train_dataset` for extractive QA.

    Supports gradient accumulation, optional apex fp16, resuming from a
    "checkpoint-<step>" directory, periodic evaluation, and periodic
    checkpointing of model/tokenizer/optimizer/scheduler state.

    Args:
        args: training namespace (batch sizes, lr, fp16, paths, ...).
        train_dataset: dataset of QA features.
        model: the transformer QA model (already on args.device).
        tokenizer: its tokenizer, saved alongside checkpoints.
        writer: a SummaryWriter-like object for scalar logging.

    Returns:
        (global_step, mean training loss per logged batch).
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    # Total optimizer steps over all epochs (accumulation reduces the count).
    train_total = len(
        train_dataloader
    ) // args.gradient_accumulation_steps * args.num_train_epochs
    # Standard BERT recipe: no weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=train_total)

    # Resume optimizer/scheduler state if present next to the pretrained model.
    if os.path.isfile(os.path.join(
            args.pretrain_model_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.pretrain_model_path, "scheduler.pt")):
        optimizer.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "scheduler.pt")))
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    print("***** Running training *****")

    global_step = 0
    steps_trained_in_current_epoch = 0

    # When resuming from a "checkpoint-<step>" directory, recover the step
    # counter and skip the batches already consumed in the current epoch.
    if os.path.exists(args.pretrain_model_path
                      ) and "checkpoint" in args.pretrain_model_path:
        global_step = int(
            args.pretrain_model_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

    train_loss, logging_loss = 0.0, 0.0
    model.zero_grad()

    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4]
            }

            # Only BERT-style models consume token_type_ids.
            inputs["token_type_ids"] = (batch[2] if args.model_type
                                        in ["bert"] else None)
            outputs = model(**inputs)
            loss = outputs[0]

            # NOTE(review): logged against the per-epoch `step`, so the
            # x-axis restarts each epoch — global_step may be intended here.
            writer.add_scalar("Train_loss", loss.item(), step)

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            pbar(step, {'loss': loss.item()})
            train_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # optimizer.step() must precede scheduler.step() on
                # PyTorch >= 1.1; the original called them in the reverse
                # order, which skews the LR schedule by one step.
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.local_rank == -1:
                        evaluate(args, model, tokenizer, writer)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel/DDP before saving.
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    # Original passed logging-style args to print(), which
                    # printed the literal "%s" and the path as a tuple.
                    print("Saving model checkpoint to %s" % output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))

        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    # max() guards against ZeroDivisionError when the dataloader was empty
    # and no optimizer step ever ran.
    return global_step, train_loss / max(global_step, 1)
Example #6
0
def train():
    """Train the ParrotModel using the parameters defined in the config file.

    Builds the model, optimiser, train/validation dataloaders and a
    TensorBoard writer, then runs ``cfg.epochs`` epochs of CTC training
    with per-step logging, periodic checkpointing and validation.
    """
    print('Initialising ...')
    cfg = TrainConfig()
    checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name)

    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)

    tb_folder = 'tb/{}/'.format(cfg.experiment_name)
    if not os.path.exists(tb_folder):
        os.makedirs(tb_folder)

    writer = SummaryWriter(logdir=tb_folder, flush_secs=30)
    model = ParrotModel().cuda().train()
    optimiser = AdamW(model.parameters(),
                      lr=cfg.initial_lr,
                      weight_decay=cfg.weight_decay)

    train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.workers,
                              collate_fn=parrot_collate_function,
                              pin_memory=True)

    val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.batch_size,
                            num_workers=cfg.workers,
                            collate_fn=parrot_collate_function,
                            shuffle=False,
                            pin_memory=True)

    epochs = cfg.epochs
    init_loss, step = 0., 0
    avg_loss = AverageMeter()
    print('Starting training')
    for epoch in range(epochs):
        loader_length = len(train_loader)
        epoch_start = time.time()

        for batch_idx, batch in enumerate(train_loader):
            optimiser.zero_grad()

            # VRAM control by skipping long examples
            if batch['spectrograms'].shape[-1] > cfg.max_time:
                continue

            # inference
            target = batch['targets'].cuda()
            model_input = batch['spectrograms'].cuda()
            model_output = model(model_input)

            # loss
            input_lengths = batch['input_lengths'].cuda()
            target_lengths = batch['target_lengths'].cuda()
            loss = ctc_loss(model_output, target, input_lengths,
                            target_lengths)
            loss.backward()

            # Store a Python float, not the tensor: keeping the loss tensor
            # alive would also keep its autograd metadata (and GPU memory)
            # referenced for the rest of the run.
            if epoch == 0 and batch_idx == 0:
                init_loss = loss.item()

            # logging (ETA is only estimated once >0.1% of the epoch is done,
            # to avoid dividing by ~zero progress)
            elapsed = time.time() - epoch_start
            progress = batch_idx / loader_length
            est = datetime.timedelta(
                seconds=int(elapsed / progress)) if progress > 0.001 else '-'
            avg_loss.update(loss.item())
            suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format(
                avg_loss.avg, init_loss,
                datetime.timedelta(seconds=int(elapsed)), est)
            printProgressBar(batch_idx,
                             loader_length,
                             suffix=suffix,
                             prefix='Epoch [{}/{}]\tStep [{}/{}]'.format(
                                 epoch, epochs, batch_idx, loader_length))

            # .item() detaches before handing the scalar to TensorBoard
            writer.add_scalar('Steps/train_loss', loss.item(), step)

            # saving the model
            if step % cfg.checkpoint_every == 0:
                test_name = '{}/test_epoch{}.mp3'.format(
                    checkpoint_folder, epoch)
                test_mp3_file(cfg.test_mp3, model, test_name)
                checkpoint_name = '{}/epoch_{}.pth'.format(
                    checkpoint_folder, epoch)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'epoch': epoch,
                        'batch_idx': loader_length,
                        'step': step,
                        'optimiser': optimiser.state_dict()
                    }, checkpoint_name)

            # validating (val() may switch the model to eval mode, so switch
            # back to train mode afterwards)
            if step % cfg.val_every == 0:
                val(model, val_loader, writer, step)
                model = model.train()

            step += 1
            optimiser.step()

        # end of epoch: log the epoch-average loss, run the test clip and
        # write an end-of-epoch checkpoint
        print('')
        writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch)
        avg_loss.reset()
        test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch)
        test_mp3_file(cfg.test_mp3, model, test_name)
        checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'batch_idx': loader_length,
                'step': step,
                'optimiser': optimiser.state_dict()
            }, checkpoint_name)

    # finished training
    writer.close()
    print('Training finished :)')
Example #7
0
def train(args, train_dataset, model, tokenizer):
    """Train `model` on `train_dataset`.

    Standard transformers-style loop: gradient accumulation, optional apex
    fp16, multi-GPU (DataParallel) and distributed (DDP) wrapping, linear
    warmup/decay schedule, periodic evaluation + logging, periodic
    checkpointing, and resumption from a ``checkpoint-<step>`` directory.

    Returns:
        (global_step, tr_loss / global_step): the number of optimisation
        steps performed and the average loss per optimisation step.
    """

    # Effective per-step batch size across all visible GPUs.
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Random sampling on a single process; DistributedSampler shards the
    # dataset across ranks when local_rank != -1.
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # t_total counts optimiser updates (not micro-batches): either fixed by
    # max_steps (and epochs derived from it) or derived from epochs.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    # Weight decay is disabled for biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        # (parses the "<step>" suffix of a "checkpoint-<step>" directory name;
        # falls back to 0 for paths without that suffix)
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    # Progress bars only on the main process (rank -1 or 0).
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproductibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            inputs["token_type_ids"] = (
                batch[2]
                if args.model_type in ["bert", "xlnet", "albert"] else None
            )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # NOTE(review): raw debug print every 10 micro-batches; consider
            # routing through `logger` like the rest of this function.
            if step % 10 == 0:
                print(step, loss.item())

            tr_loss += loss.item()
            # Step the optimiser once per accumulation window; the second
            # clause flushes a final partial window when the epoch is shorter
            # than one accumulation window.
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    # Average loss over the steps since the last log line.
                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
def main():
    """Entry point: fine-tune a pretrained image classifier.

    Builds train/validation dataloaders, loads a pretrained backbone with a
    new FC head, trains for ``args.epochs`` epochs with cosine-annealed
    AdamW, optionally prunes low-confidence training images every
    ``args.cleanup_epoch`` epochs, and checkpoints every 5 epochs.
    """
    args = parseArguments()

    os.makedirs(args.modelDir, exist_ok=True)
    checkpointDir = os.path.join(args.modelDir, 'checkpoints')
    os.makedirs(checkpointDir, exist_ok=True)

    os.makedirs(args.ensembleDir, exist_ok=True)

    with EventTimer('Preparing for dataset / dataloader'):
        trainDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.trainImages),
                                      transform=trainingPreprocessing)
        validDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.validImages),
                                      transform=inferencePreprocessing)

        trainDataloader = DataLoader(trainDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=True)
        validDataloader = DataLoader(validDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=False)

        print(f'> Training dataset:\t{len(trainDataset)}')
        print(f'> Validation dataset:\t{len(validDataset)}')

    with EventTimer(f'Load pretrained model - {args.pretrainModel}'):
        model = models.GetPretrainedModel(args.pretrainModel,
                                          fcDims=args.fcDims + [42])
        print(model)
        #torchsummary will crash under densenet, skip the summary.
        #torchsummary.summary(model, (3, 224, 224), device='cpu')

    with EventTimer(f'Train model'):
        model.cuda()

        criterion = CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=args.epochs,
                                      eta_min=1e-6)
        history = []

        # Resume model/optimizer/scheduler/history from a checkpoint
        # when --retrain is given.
        if args.retrain != 0:
            checkpoint = torch.load(
                os.path.join(checkpointDir,
                             f'checkpoint-{args.retrain:03d}.pt'))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            history = checkpoint['history']

        def runEpoch(dataloader, train=False, name=''):
            """Run one epoch over `dataloader`; returns (mean loss, mean accuracy).

            With train=True, runs forward/backward and optimiser updates;
            otherwise inference only.
            """
            # For empty validation dataloader
            if len(dataloader) == 0:
                return 0, 0

            # Enable grad
            with (torch.enable_grad() if train else torch.no_grad()):
                if train: model.train()
                else: model.eval()

                losses = []
                for img, label, imgPath in tqdm(dataloader,
                                                desc=name,
                                                ncols=80):
                    if train:
                        optimizer.zero_grad()

                    output = model(img.cuda()).cpu()
                    loss = criterion(output, label)

                    if train:
                        loss.backward()
                        optimizer.step()

                    accu = accuracy(output.data.numpy(), label.numpy())
                    losses.append((loss.item(), accu))

            return map(np.mean, zip(*losses))

        def cleanUp():
            """Drop the ~5% of training images the model is least confident on.

            Scores every training image by its max softmax-logit, computes the
            5th-percentile threshold, and returns a new dataloader over the
            images at or above it.
            """
            model.eval()
            train_pred = np.zeros((trainDataloader.__len__()) * args.batchSize)
            cnt = 0
            # Inference only: no_grad avoids building autograd graphs here.
            with torch.no_grad():
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    train_pred[cnt:cnt + len(pred)] = pred
                    cnt += len(pred)

            # Only the first `cnt` entries hold real scores; the buffer was
            # sized assuming full batches, so slice off the zero padding
            # before taking the 5th-percentile threshold (otherwise the
            # padding zeros skew the threshold toward 0).
            sorted_pred = train_pred[:cnt].copy()
            sorted_pred.sort()
            threshold = sorted_pred[(len(sorted_pred) // 20)]
            data_set = [[], []]

            with torch.no_grad():
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    for j in range(len(pred)):
                        if pred[j] >= threshold:
                            data_set[0].append(path[j])
                            data_set[1].append(label[j])

            newDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                        os.path.join(args.trainImages),
                                        transform=trainingPreprocessing,
                                        data=data_set)
            newDataloader = DataLoader(newDataset,
                                       batch_size=args.batchSize,
                                       num_workers=args.numWorkers,
                                       shuffle=True)

            print(
                f"{newDataloader.__len__() * args.batchSize} images remain after cleanup"
            )
            return newDataloader

        for epoch in range(args.retrain + 1, args.epochs + 1):
            with EventTimer(verbose=False) as et:
                print(f'====== Epoch {epoch:3d} / {args.epochs:3d} ======')
                trainLoss, trainAccu = runEpoch(trainDataloader,
                                                train=True,
                                                name='training  ')
                validLoss, validAccu = runEpoch(validDataloader,
                                                name='validation')

                history.append(
                    ((trainLoss, trainAccu), (validLoss, validAccu)))

                scheduler.step()
                print(
                    f'[{et.gettime():.4f}s] Training: {trainLoss:.6f} / {trainAccu:.4f} ; Validation {validLoss:.6f} / {validAccu:.4f}'
                )

            if args.cleanup and epoch % args.cleanup_epoch == 0:
                with EventTimer('Cleaning Training Set'):
                    trainDataloader = cleanUp()

            if epoch % 5 == 0:
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'history': history,
                    }, os.path.join(checkpointDir,
                                    f'checkpoint-{epoch:03d}.pt'))

        # save model as its corresponding name
        torch.save(model.state_dict(),
                   os.path.join(args.modelDir, 'model-weights.pt'))
        utils.pickleSave(history, os.path.join(args.modelDir, 'history.pkl'))
Example #9
0
class Trainer():
    """Trainer for a Seq2Seq spelling-correction model over character n-grams.

    Configuration (DEVICE, NUM_ITERS, BATCH_SIZE, ...) comes from
    module-level constants. Training data is synthesized on the fly by
    `SynthesizeData`; the model is trained by iteration count rather than
    epochs, with periodic console logging, validation, and checkpointing.
    """
    def __init__(self, alphabets_, list_ngram):
        """Build vocab, train/valid split, model, optimiser, scheduler and loss.

        Arguments:
            alphabets_: alphabet passed to `Vocab`.
            list_ngram: full list of n-gram samples; split 90/10 train/valid.
        """

        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        # Input and output spaces are both the character vocabulary.
        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        # NOTE(review): self.logger is only created when LOG is truthy, but
        # train() calls self.logger.log(...) unconditionally — confirm LOG is
        # always set, otherwise train() raises AttributeError.
        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Split `list_phrases` sequentially into train/valid lists.

        No shuffling: the first (1 - test_size) fraction becomes training
        data, the remainder validation data.
        """
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Wrap an AutoCorrectDataset in a DataLoader (shuffled when training)."""
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)

        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)

        return gen

    def step(self, batch):
        """Run one optimisation step on `batch`; returns the loss as a float."""
        self.model.train()

        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch

        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab

        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)

        tgt_output = tgt.transpose(0, 1).reshape(
            -1)  # flatten()   # tgt: tgt_len xB , need convert to B x tgt_len

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        # Gradient clipping at norm 1 before the optimiser update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def train(self):
        """Main training loop: run `num_iters` steps with periodic logging,
        validation, best-weight export, and checkpointing."""
        print("Begin training from iter: ", self.iter)
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        # Re-create the iterator when the dataloader is exhausted so training
        # can run for more iterations than one pass over the data.
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            # Console + logger report every `print_every` iterations; the
            # running totals are reset afterwards.
            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            # Validation every `valid_every` iterations; exports weights when
            # full-sequence accuracy improves, and always saves a checkpoint.
            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute mean validation loss over up to `metrics` samples.

        Returns (mean loss, first 3 predictions, first 3 references,
        first 3 input sentences) for qualitative inspection.
        """
        self.model.eval()

        total_loss = []
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                # NOTE(review): predict(5) re-runs inference over valid_gen on
                # every validation batch; only the last call's result is used.
                preds, actuals, inp_sents, probs = self.predict(5)

                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode validation batches (beam search or greedy translate).

        Stops once more than `sample` sentences are collected (when given).
        Returns (predictions, references, inputs, prob) where `prob` is the
        last batch's probability output (None under beam search).
        """
        pred_sents = []
        actual_sents = []
        inp_sents = []

        # NOTE(review): `prob` is unbound if valid_gen yields no batches —
        # assumes a non-empty validation loader; confirm upstream.
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-char accuracy, CER) over up to
        `sample` decoded validation sentences."""

        pred_sents, actual_sents, _, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')

        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect up to `sample` predictions (optionally mispredictions only).

        NOTE(review): builds `fontdict` but never renders anything — the
        plotting code appears to have been removed; confirm intent.
        """

        pred_sents, actual_sents, img_files, probs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Walk up to `sample` training examples.

        NOTE(review): decodes `img`/`sent` but never displays them — assumes
        batches carry 'img'/'tgt_input' keys, unlike the 'src'/'tgt' keys
        used elsewhere; confirm against the dataset.
        """
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore model/optimizer/scheduler state, iteration counter and
        loss history from a checkpoint file."""
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Save full training state to `filename`, creating parent dirs."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights, skipping missing or shape-mismatched tensors
        (reported to stdout) via strict=False."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model weights to `filename`, creating parent dirs."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move a batch's 'src'/'tgt' tensors to the configured device."""

        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)

        batch = {'src': src, 'tgt': tgt}

        return batch
Example #10
0
class Trainer:
    """
    Handles model training and evaluation.
    
    Arguments:
    ----------
    config: A dictionary of training parameters, likely from a .yaml
    file
    
    model: A pytorch segmentation model (e.g. DeepLabV3)
    
    trn_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a training dataset
    
    val_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a validation dataset.
    
    """
    def __init__(self, config, model, trn_data, val_data=None):
        # NOTE(review): model is moved to GPU unconditionally — this class
        # assumes a CUDA device is available.
        self.config = config
        self.model = model.cuda()
        self.trn_data = DataFetcher(trn_data)
        self.val_data = val_data

        #create the optimizer
        if config['optim'] == 'SGD':
            self.optimizer = SGD(model.parameters(),
                                 lr=config['lr'],
                                 momentum=config['momentum'],
                                 weight_decay=config['wd'])
        elif config['optim'] == 'AdamW':
            self.optimizer = AdamW(
                model.parameters(), lr=config['lr'],
                weight_decay=config['wd'])  #momentum is default
        else:
            optim = config['optim']
            raise Exception(
                f'Optimizer {optim} is not supported! Must be SGD or AdamW')

        #create the learning rate scheduler
        schedule = config['lr_policy']
        if schedule == 'OneCycle':
            self.scheduler = OneCycleLR(self.optimizer,
                                        config['lr'],
                                        total_steps=config['iters'])
        elif schedule == 'MultiStep':
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=config['lr_decay_epochs'])
        elif schedule == 'Poly':
            # polynomial decay from 1 down to 0 over config['iters'] steps
            func = lambda iteration: (1 - (iteration / config['iters'])
                                      )**config['power']
            self.scheduler = LambdaLR(self.optimizer, func)
        else:
            lr_policy = config['lr_policy']
            raise Exception(
                f'Policy {lr_policy} is not supported! Must be OneCycle, MultiStep or Poly'
            )

        #create the loss criterion
        if config['num_classes'] > 1:
            #load class weights if they were given in the config file
            if 'class_weights' in config:
                weight = torch.Tensor(config['class_weights']).float().cuda()
            else:
                weight = None

            self.criterion = nn.CrossEntropyLoss(weight=weight).cuda()
        else:
            # single-class (binary) segmentation uses logits + BCE
            self.criterion = nn.BCEWithLogitsLoss().cuda()

        #define train and validation metrics and class names
        class_names = config['class_names']

        #make training metrics using the EMAMeter. this meter gives extra
        #weight to the most recent metric values calculated during training
        #this gives a better reflection of how well the model is performing
        #when the metrics are printed
        trn_md = {
            name: metric_lookup[name](EMAMeter())
            for name in config['metrics']
        }
        self.trn_metrics = ComposeMetrics(trn_md, class_names)
        self.trn_loss_meter = EMAMeter()

        #the only difference between train and validation metrics
        #is that we use the AverageMeter. this is because there are
        #no weight updates during evaluation, so all batches should
        #count equally
        val_md = {
            name: metric_lookup[name](AverageMeter())
            for name in config['metrics']
        }
        self.val_metrics = ComposeMetrics(val_md, class_names)
        self.val_loss_meter = AverageMeter()

        self.logging = config['logging']

        #now, if we're resuming from a previous run we need to load
        #the state for the model, optimizer, and schedule and resume
        #the mlflow run (if there is one and we're using logging)
        if config['resume']:
            self.resume(config['resume'])
        elif self.logging:
            #if we're not resuming, but are logging, then we
            #need to setup mlflow with a new experiment
            #everytime that Trainer is instantiated we want to
            #end the current active run and let a new one begin
            mlflow.end_run()

            #extract the experiment name from config so that
            #we know where to save our files, if experiment name
            #already exists, we'll use it, otherwise we create a
            #new experiment
            mlflow.set_experiment(self.config['experiment_name'])

            #add the config file as an artifact
            mlflow.log_artifact(config['config_file'])

            #we don't want to add everything in the config
            #to mlflow parameters, we'll just add the most
            #likely to change parameters
            mlflow.log_param('lr_policy', config['lr_policy'])
            mlflow.log_param('optim', config['optim'])
            mlflow.log_param('lr', config['lr'])
            mlflow.log_param('wd', config['wd'])
            mlflow.log_param('bsz', config['bsz'])
            mlflow.log_param('momentum', config['momentum'])
            mlflow.log_param('iters', config['iters'])
            mlflow.log_param('epochs', config['epochs'])
            mlflow.log_param('encoder', config['encoder'])
            mlflow.log_param('finetune_layer', config['finetune_layer'])
            mlflow.log_param('pretraining', config['pretraining'])

    def resume(self, checkpoint_fpath):
        """
        Sets model parameters, scheduler and optimizer states to the
        last recorded values in the given checkpoint file.
        """
        checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])

        # optionally keep optimizer/scheduler at their fresh state so that
        # training restarts from iteration 0 with the loaded weights
        if not self.config['restart_training']:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        # reattach to the mlflow run that produced this checkpoint, if any
        if self.logging and 'run_id' in checkpoint:
            mlflow.start_run(run_id=checkpoint['run_id'])

        print(f'Loaded state from {checkpoint_fpath}')
        print(f'Resuming from epoch {self.scheduler.last_epoch}...')

    def log_metrics(self, step, dataset):
        """Log the current loss meter average and all metric averages to
        mlflow, prefixed with *dataset* ('train' or 'valid')."""
        #get the corresponding losses and metrics dict for
        #either train or validation sets
        # NOTE(review): any other value of `dataset` leaves `losses`
        # unbound and raises UnboundLocalError below.
        if dataset == 'train':
            losses = self.trn_loss_meter
            metric_dict = self.trn_metrics.metrics_dict
        elif dataset == 'valid':
            losses = self.val_loss_meter
            metric_dict = self.val_metrics.metrics_dict

        #log the last loss, using the dataset name as a prefix
        mlflow.log_metric(dataset + '_loss', losses.avg, step=step)

        #log all the metrics in our dict, using dataset as a prefix
        # (class names are shared between train and validation metrics,
        # so using self.trn_metrics.class_names here works for both)
        metrics = {}
        for k, v in metric_dict.items():
            values = v.meter.avg
            for class_name, val in zip(self.trn_metrics.class_names, values):
                metrics[dataset + '_' + class_name + '_' + k] = float(
                    val.item())

        mlflow.log_metrics(metrics, step=step)

    def train(self):
        """
        Defines a pytorch style training loop for the model withtqdm progress bar
        for each epoch and handles printing loss/metrics at the end of each epoch.
        
        epochs: Number of epochs to train model
        train_iters_per_epoch: Number of training iterations is each epoch. Reducing this 
        number will give more frequent updates but result in slower training time.
        
        Results:
        ----------
        
        After train_iters_per_epoch iterations are completed, it will evaluate the model
        on val_data if there is any, then prints loss and metrics for train and validation
        datasets.
        """

        #set the inner and outer training loop as either
        #iterations or epochs depending on our scheduler
        # (iteration-based schedulers treat each iteration as an "epoch")
        if self.config['lr_policy'] != 'MultiStep':
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['iters']
            iters_per_epoch = 1
            outer_loop = tqdm(range(last_epoch, total_epochs + 1),
                              file=sys.stdout,
                              initial=last_epoch,
                              total=total_epochs)
            inner_loop = range(iters_per_epoch)
        else:
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['epochs']
            iters_per_epoch = len(self.trn_data)
            outer_loop = range(last_epoch, total_epochs + 1)
            inner_loop = tqdm(range(iters_per_epoch), file=sys.stdout)

        #determine the epochs at which to print results
        # NOTE(review): integer division — if num_prints or
        # num_save_checkpoints exceeds total_epochs this is 0 and the
        # modulo below raises ZeroDivisionError; confirm configs keep
        # these <= total_epochs.
        eval_epochs = total_epochs // self.config['num_prints']
        save_epochs = total_epochs // self.config['num_save_checkpoints']

        #the cudnn.benchmark flag speeds up performance
        #when the model input size is constant. See:
        #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

        #perform training over the outer and inner loops
        for epoch in outer_loop:
            for iteration in inner_loop:
                #load the next batch of training data
                images, masks = self.trn_data.load()

                #run the training iteration
                loss, output = self._train_1_iteration(images, masks)

                #record the loss and evaluate metrics
                self.trn_loss_meter.update(loss)
                self.trn_metrics.evaluate(output, masks)

            #when we're at an eval_epoch we want to print
            #the training results so far and then evaluate
            #the model on the validation data
            if epoch % eval_epochs == 0:
                #before printing results let's record everything in mlflow
                #(if we're using logging)
                if self.logging:
                    self.log_metrics(epoch, dataset='train')

                print('\n')  #print a new line to give space from progess bar
                print(f'train_loss: {self.trn_loss_meter.avg:.3f}')
                self.trn_loss_meter.reset()
                #prints and automatically resets the metric averages to 0
                self.trn_metrics.print()

                #run evaluation if we have validation data
                if self.val_data is not None:
                    #before evaluation we want to turn off cudnn
                    #benchmark because the input sizes of validation
                    #images are not necessarily constant
                    cudnn.benchmark = False
                    self.evaluate()

                    if self.logging:
                        self.log_metrics(epoch, dataset='valid')

                    print(
                        '\n')  #print a new line to give space from progess bar
                    print(f'valid_loss: {self.val_loss_meter.avg:.3f}')
                    self.val_loss_meter.reset()
                    #prints and automatically resets the metric averages to 0
                    self.val_metrics.print()

                    #turn cudnn.benchmark back on before returning to training
                    cudnn.benchmark = True

            #update the optimizer schedule
            self.scheduler.step()

            #the last step is to save the training state if
            #at a checkpoint
            if epoch % save_epochs == 0:
                self.save_state(epoch)

    def _train_1_iteration(self, images, masks):
        """Run a single forward/backward/step on one batch; returns the
        scalar loss value and the detached model output."""
        #run a training step
        self.model.train()
        self.optimizer.zero_grad()

        #forward pass
        output = self.model(images)
        loss = self.criterion(output, masks)

        #backward pass
        loss.backward()
        self.optimizer.step()

        #return the loss value and the output
        return loss.item(), output.detach()

    def evaluate(self):
        """
        Evaluation method used at the end of each epoch. Not intended to
        generate predictions for validation dataset, it only returns average loss
        and stores metrics for validaiton dataset.
        
        Use Validator class for generating masks on a dataset.
        """
        #set the model into eval mode
        self.model.eval()

        val_iter = DataFetcher(self.val_data)
        for _ in range(len(val_iter)):
            with torch.no_grad():
                #load batch of data
                images, masks = val_iter.load()
                # NOTE(review): .eval() here is redundant (already set
                # above) — it returns the model itself, so this is just
                # a forward pass.
                output = self.model.eval()(images)
                loss = self.criterion(output, masks)
                self.val_loss_meter.update(loss.item())
                self.val_metrics.evaluate(output.detach(), masks)

        #loss and metrics are updated inplace, so there's nothing to return
        return None

    def save_state(self, epoch):
        """
        Saves the self.model state dict
        
        Arguments:
        ------------
        
        save_path: Path of .pt file for saving
        
        Example:
        ----------
        
        trainer = Trainer(...)
        trainer.save_model(model_path + 'new_model.pt')
        """

        #save the state together with the norms that we're using
        state = {
            'state_dict': self.model.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'norms': self.config['training_norms']
        }

        # record the mlflow run id so resume() can reattach to the run
        if self.logging:
            state['run_id'] = mlflow.active_run().info.run_id

        #the last step is to create the name of the file to save
        #the format is: name-of-experiment_pretraining_epoch.pth
        model_dir = self.config['model_dir']
        exp_name = self.config['experiment_name']
        pretraining = self.config['pretraining']
        ft_layer = self.config['finetune_layer']

        if self.config['lr_policy'] != 'MultiStep':
            total_epochs = self.config['iters']
        else:
            total_epochs = self.config['epochs']

        if os.path.isfile(pretraining):
            #this is slightly clunky, but it handles the case
            #of using custom pretrained weights from a file
            #usually there aren't any '.'s other than the file
            #extension
            pretraining = pretraining.split('/')[-2]  #.split('.')[0]

        save_path = os.path.join(
            model_dir,
            f'{exp_name}-{pretraining}_ft_{ft_layer}_epoch{epoch}_of_{total_epochs}.pth'
        )
        torch.save(state, save_path)
Example #11
0
class Detector(object):
    """Wraps a detection model with its optimizer and OneCycle LR schedule,
    providing training (plain and mixup), evaluation, and checkpoint helpers.

    Arguments:
    ----------
    cfg: dict with at least 'device', 'network' and 'nepochs' keys.
    """
    def __init__(self, cfg):
        self.device = cfg["device"]
        self.model = Models().get_model(cfg["network"]) # cfg.network
        self.model.to(self.device)
        # optimize only the parameters that require gradients
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=0.00001)
        self.lr_scheduler = OneCycleLR(self.optimizer,
                                       max_lr=1e-4,
                                       epochs=cfg["nepochs"],
                                       steps_per_epoch=169,  # len(dataloader)/accumulations
                                       div_factor=25,  # for initial lr, default: 25
                                       final_div_factor=1e3,  # for final lr, default: 1e4
                                       )

    def fit(self, data_loader, accumulation_steps=4, wandb=None):
        """Train for one epoch with gradient accumulation.

        Returns ({'train_avg_loss': avg_loss}, avg_component_loss_dict).
        Exits the process if the loss becomes non-finite.
        """
        self.model.train()
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        lr_log = MetricLogger('list')

        self.optimizer.zero_grad()
        device = self.device

        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # detection models return a dict of named losses in train mode
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            # step only every `accumulation_steps` batches; NOTE: any
            # leftover gradients from a final partial accumulation window
            # are never applied
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    lr_log.update(self.lr_scheduler.get_last_lr())

            print(f"\rTrain iteration: [{i+1}/{len(data_loader)}]", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
        print()
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def mixup_fit(self, data_loader, accumulation_steps=4, wandb=None):
        """Train for one epoch on mixup-augmented batch pairs.

        data_loader must yield (batch1, batch2) pairs; images are mixed
        with mixup_images and targets merged with merge_targets.
        Returns ({'train_avg_loss': avg_loss}, avg_component_loss_dict).
        """
        self.model.train()
        torch.cuda.empty_cache()
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')

        self.optimizer.zero_grad()
        device = self.device

        for i, (batch1, batch2) in enumerate(data_loader):
            images1, targets1 = batch1
            images2, targets2 = batch2
            images = mixup_images(images1, images2)
            targets = merge_targets(targets1, targets2)
            # free the un-mixed tensors before moving data to the GPU
            del images1, images2, targets1, targets2, batch1, batch2

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            # was a hard-coded total of 674; use the loader length as fit() does
            print(f"Train iteration: [{i+1}/{len(data_loader)}]\r", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
        print()
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def evaluate(self, val_dataloader):
        """Compute the mean mAP over the validation set.

        Predictions with score > 0.6 are kept and scored against ground
        truth boxes at IoU thresholds 0.5–0.75.
        Returns {'validation_mAP_score': average}.
        """
        device = self.device
        torch.cuda.empty_cache()
        self.model.eval()
        mAp_logger = MetricLogger('list')
        with torch.no_grad():
            for (j, batch) in enumerate(val_dataloader):
                print(f"\rValidation: [{j+1}/{len(val_dataloader)}]", end="")
                images, targets = batch
                del batch
                images = [img.to(device) for img in images]
                predictions = self.model(images)
                for i, pred in enumerate(predictions):
                    probas = pred["scores"].detach().cpu().numpy()
                    mask = probas > 0.6
                    preds = pred["boxes"].detach().cpu().numpy()[mask]
                    gts = targets[i]["boxes"].detach().cpu().numpy()
                    score, scores = map_score(gts, preds, thresholds=[.5, .55, .6, .65, .7, .75])
                    mAp_logger.update(scores)
            print()
        return {"validation_mAP_score": mAp_logger.avg}

    def get_checkpoint(self):
        """Return a checkpoint dict with model and optimizer state.

        NOTE(review): the LR scheduler state is deliberately not included;
        restoring mid-cycle OneCycleLR state is unsupported here.
        """
        self.model.eval()
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        checkpoint = {'model_state_dict': model_state,
                      'optimizer_state_dict': optimizer_state
                      }
        return checkpoint

    def load_checkpoint(self, checkpoint):
        """Restore model and optimizer state from a get_checkpoint() dict."""
        self.model.eval()
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
def main() -> None:
    """Entry point: parse args, build datasets/model/optimizer, optionally
    resume from a checkpoint, then train (or only validate) and periodically
    save checkpoints.

    Relies on module-level `parser`, `best_loss`, and the train/validate/
    save_checkpoint helpers.
    """
    global best_loss

    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    start_epoch = 0

    vcf_reader = VCFReader(args.train_data, args.classification_map,
                           args.chromosome, args.class_hierarchy)
    vcf_writer = vcf_reader.get_vcf_writer()
    train_dataset, validation_dataset = vcf_reader.get_datasets(
        args.validation_split)
    train_sampler = BatchByLabelRandomSampler(args.batch_size,
                                              train_dataset.labels)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)

    # NOTE(review): validation_loader is only defined when
    # validation_split != 0; --validate with a zero split would fail below.
    if args.validation_split != 0:
        validation_sampler = BatchByLabelRandomSampler(
            args.batch_size, validation_dataset.labels)
        validation_loader = DataLoader(validation_dataset,
                                       batch_sampler=validation_sampler)

    # model constructor kwargs are also stored in checkpoints so a resumed
    # run can verify it is loading weights into an identical architecture
    kwargs = {
        'total_size': vcf_reader.positions.shape[0],
        'window_size': args.window_size,
        'num_layers': args.layers,
        'num_classes': len(vcf_reader.label_encoder.classes_),
        'num_super_classes': len(vcf_reader.super_label_encoder.classes_)
    }
    model = WindowedMLP(**kwargs)
    model.to(get_device(args))

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    # optionally resume training state from a checkpoint file
    if args.resume_path is not None:
        if os.path.isfile(args.resume_path):
            print("=> loading checkpoint '{}'".format(args.resume_path))
            checkpoint = torch.load(args.resume_path)
            if kwargs != checkpoint['model_kwargs']:
                raise ValueError(
                    'The checkpoint\'s kwargs don\'t match the ones used to initialize the model'
                )
            if vcf_reader.snps.shape[0] != checkpoint['vcf_writer'].snps.shape[
                    0]:
                raise ValueError(
                    'The data on which the checkpoint was trained had a different number of snp positions'
                )
            start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume_path, checkpoint['epoch']))
        else:
            # bug fix: was args.resume (nonexistent attribute) -> AttributeError
            print("=> no checkpoint found at '{}'".format(args.resume_path))

    # validation-only mode: evaluate once and exit
    if args.validate:
        validate(validation_loader, model,
                 nn.functional.binary_cross_entropy_with_logits,
                 len(vcf_reader.label_encoder.classes_),
                 len(vcf_reader.super_label_encoder.classes_), vcf_reader.maf,
                 args)
        return

    for epoch in range(start_epoch, args.epochs + start_epoch):
        loss = train(train_loader, model,
                     nn.functional.binary_cross_entropy_with_logits, optimizer,
                     len(vcf_reader.label_encoder.classes_),
                     len(vcf_reader.super_label_encoder.classes_),
                     vcf_reader.maf, epoch, args)

        # checkpoint every save_freq epochs and always on the final epoch;
        # track the best loss (validation loss when a validation split exists)
        if epoch % args.save_freq == 0 or epoch == args.epochs + start_epoch - 1:
            if args.validation_split != 0:
                validation_loss = validate(
                    validation_loader, model,
                    nn.functional.binary_cross_entropy_with_logits,
                    len(vcf_reader.label_encoder.classes_),
                    len(vcf_reader.super_label_encoder.classes_),
                    vcf_reader.maf, args)
                is_best = validation_loss < best_loss
                best_loss = min(validation_loss, best_loss)
            else:
                is_best = loss < best_loss
                best_loss = min(loss, best_loss)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'model_kwargs': kwargs,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'vcf_writer': vcf_writer,
                    'label_encoder': vcf_reader.label_encoder,
                    'super_label_encoder': vcf_reader.super_label_encoder,
                    'maf': vcf_reader.maf
                }, is_best, args.chromosome, args.model_name, args.model_dir)
class SAJEM():
    '''
    Self-Attention based Joint Embedding Model
    Consist of 2 branches to encode image and text
    '''
    def __init__(self,
                 image_encoder,
                 text_encoder,
                 image_mha,
                 bert_model,
                 optimizer='adam',
                 lr=1e-3,
                 l2_regularization=1e-2,
                 margin_loss=1e-2,
                 max_violation=True,
                 cost_style='mean',
                 use_lr_scheduler=False,
                 grad_clip=0,
                 num_training_steps=30000,
                 device='cuda'):
        self.image_mha = image_mha
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.bert_model = bert_model
        self.device = device

        self.use_lr_scheduler = use_lr_scheduler
        # collect all trainable parameters (used for gradient clipping)
        self.params = list(self.image_mha.parameters())
        self.params += list(self.text_encoder.parameters())
        self.params += list(self.image_encoder.parameters())
        self.params += list(self.bert_model.parameters())
        self.grad_clip = grad_clip
        # frozen == True while the text branch is detached during epoch 1
        self.frozen = False
        # two parameter groups: a small LR for BERT, a larger one for the
        # projection/attention modules
        if optimizer == 'adamW':
            self.optimizer = AdamW([{
                'params':
                list(self.bert_model.parameters()),
                'lr':
                3e-5
            }, {
                'params':
                list(self.image_encoder.parameters()) +
                list(self.text_encoder.parameters()) +
                list(self.image_mha.parameters()),
                'lr':
                1e-4
            }])
        elif optimizer == 'adam':
            self.optimizer = torch.optim.Adam([{
                'params':
                list(self.bert_model.parameters()),
                'lr':
                3e-5
            }, {
                'params':
                list(self.image_encoder.parameters()) +
                list(self.text_encoder.parameters()) +
                list(self.image_mha.parameters()),
                'lr':
                1e-4
            }])

        if self.use_lr_scheduler:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=100,
                num_training_steps=num_training_steps)
        # constant-LR schedule used while the text branch is frozen
        # NOTE(review): forward() deletes lr_scheduler_0 after epoch 1, so
        # step_scheduler() must not hit the else branch after that point
        self.lr_scheduler_0 = get_constant_schedule(self.optimizer)
        # loss
        self.mrl_loss = MarginRankingLoss(margin=margin_loss,
                                          max_violation=max_violation,
                                          cost_style=cost_style,
                                          direction='bidir')

    def forward(self, image_feature, image_attention_mask, input_ids,
                attention_mask, epoch):
        """Encode images and text into the joint embedding space.

        During epoch 1 the text features are detached (branch frozen);
        from epoch 2 on the frozen-phase scheduler is released.
        Returns (image_to_common, text_to_common).
        """
        if epoch > 1 and self.frozen:
            self.frozen = False
            del self.lr_scheduler_0
            torch.cuda.empty_cache()

        image_feature = l2norm(image_feature).detach()
        final_image_features = l2norm(
            self.image_mha(image_feature, image_attention_mask))
        text_feature = self.bert_model(input_ids,
                                       attention_mask=attention_mask)
        text_feature = l2norm(text_feature)
        if epoch == 1:
            text_feature = text_feature.detach()
            self.frozen = True
        image_to_common = self.image_encoder(final_image_features)
        text_to_common = self.text_encoder(text_feature)
        return image_to_common, text_to_common

    def save_network(self, folder):
        """Save all sub-module, optimizer and (optionally) scheduler states
        as separate .pth files inside *folder*."""
        torch.save(self.image_mha.state_dict(),
                   os.path.join(folder, 'image_mha.pth'))
        torch.save(self.text_encoder.state_dict(),
                   os.path.join(folder, 'text_encoder.pth'))
        torch.save(self.image_encoder.state_dict(),
                   os.path.join(folder, 'image_encoder.pth'))
        torch.save(self.bert_model.state_dict(),
                   os.path.join(folder, 'bert_model.pth'))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(folder, 'optimizer.pth'))
        if self.use_lr_scheduler:
            torch.save(self.lr_scheduler.state_dict(),
                       os.path.join(folder, 'scheduler.pth'))

    def switch_to_train(self):
        """Put all sub-modules into training mode."""
        self.image_mha.train()
        self.text_encoder.train()
        self.image_encoder.train()
        self.bert_model.train()

    def switch_to_eval(self):
        """Put all sub-modules into evaluation mode."""
        self.image_mha.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()
        self.bert_model.eval()

    def train(self, image_features, image_attention_mask, input_ids,
              attention_mask, epoch):
        """One optimization step on a batch; returns the scalar loss."""
        self.switch_to_train()
        image_to_common, text_to_common = self.forward(image_features,
                                                       image_attention_mask,
                                                       input_ids,
                                                       attention_mask, epoch)
        self.optimizer.zero_grad()

        # Compute loss
        loss = self.mrl_loss(text_to_common, image_to_common)
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad.clip_grad_norm_(self.params,
                                                     self.grad_clip)

        self.optimizer.step()
        return loss.item()

    def step_scheduler(self):
        """Advance the active LR schedule (warmup schedule once unfrozen)."""
        if self.use_lr_scheduler and not self.frozen:
            self.lr_scheduler.step()
        else:
            self.lr_scheduler_0.step()

    def evaluate(self, val_image_dataloader, val_text_dataloader, k):
        """Compute text-to-image recall@k over the validation loaders."""
        self.switch_to_eval()
        # Load image features
        with torch.no_grad():
            image_features = []
            image_ids = []
            for ids, features, image_attention_mask in val_image_dataloader:
                image_ids.append(torch.stack(ids))
                features = torch.stack(features).to(self.device)
                image_attention_mask = torch.stack(image_attention_mask).to(
                    self.device)
                features = l2norm(features).detach()
                mha_features = l2norm(
                    self.image_mha(features, image_attention_mask))
                image_features.append(self.image_encoder(mha_features))
            image_features = torch.cat(image_features, dim=0)
            image_ids = torch.cat(image_ids, dim=0).to(self.device)
            # Evaluate
            recall = 0
            total_query = 0
            pbar = tqdm(enumerate(val_text_dataloader),
                        total=len(val_text_dataloader),
                        leave=False,
                        position=0,
                        file=sys.stdout)
            for i, (image_files, input_ids, attention_mask) in pbar:
                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                text_features = self.bert_model(input_ids,
                                                attention_mask=attention_mask)
                text_features = l2norm(text_features)
                text_features = self.text_encoder(text_features)
                # extract the 12-digit COCO-style image id from each filename
                # bug fix: was .to(device) with no local `device` defined
                image_files = torch.tensor(
                    list(
                        map(lambda x: int(re.findall(r'\d{12}', x)[0]),
                            image_files))).to(self.device)
                top_k = get_top_k_eval(text_features, image_features, k)
                for idx, indices in enumerate(top_k):
                    total_query += 1
                    true_image_id = image_files[idx]
                    top_k_images = torch.gather(image_ids, 0, indices)
                    if (top_k_images == true_image_id).nonzero().numel() > 0:
                        recall += 1
            recall = recall / total_query
            return recall
Example #14
0
def main(args):
    """Train a model end to end: load chunked data, build the model and
    optimizer, run the epoch loop, and checkpoint weights/optimizer state
    plus a CSV training log into the working directory.

    Args:
        args: parsed CLI namespace (training_directory, force, seed, device,
            chunks, directory, validation_split, batch, config, lr, epochs,
            amp, multi_gpu).
    """
    workdir = os.path.expanduser(args.training_directory)

    # Refuse to clobber an existing run unless the user passes force.
    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, targets, lengths = load_data(limit=args.chunks, shuffle=True, directory=args.directory)

    # First `validation_split` fraction of the (shuffled) chunks trains,
    # the remainder validates.
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], targets[:split], lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], targets[split:], lengths[split:])
    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    # Merge in any config shipped alongside the chunk data, if present.
    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # Fix: was toml.load(os.path.join(chunk_config_file)) — a redundant
        # single-argument os.path.join around an already-joined path.
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    # Persist the merged config (later dicts win on key collisions); use a
    # context manager so the file handle is closed deterministically.
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict, **chunk_config}, conf_file)

    print("[loading model]")
    model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    # Resume from the most recent checkpoint in workdir (0 if fresh).
    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    # Cosine decay with warmup, offset by the steps already taken.
    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch*len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # Re-expose attributes hidden behind the DataParallel wrapper.
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer, criterion=criterion,
                use_amp=args.amp, lr_scheduler=lr_scheduler
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            # Ctrl-C stops training cleanly, keeping earlier checkpoints.
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        # Checkpoint weights and optimizer state every epoch for resuming.
        model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
        torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
        torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

        with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            # Header only at the very first epoch of a fresh run.
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])
                                  global_step)
                logger.add_images("test/2_output_outline",
                                  unnormalize(output) * (1 - outline),
                                  global_step)
                # Log these only once
                if first_run:
                    logger.add_images("test/1_target", unnormalize(target),
                                      global_step)
                    logger.add_images("test/3_target_outline",
                                      unnormalize(target) * (1 - outline),
                                      global_step)
                    logger.add_images("test/4_input_morphed",
                                      unnormalize(GMM_morph), global_step)
                    logger.add_images("test/5_input_outline", outline,
                                      global_step)
                    first_run = False

            if global_step % save_interval == 0:
                output_path = Path(
                    f"./training/checkpoints/{run_id}/E{e}_L{loss_G.item()}.pth"
                )
                output_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "G": G.state_dict(),
                        "e": e,
                        "i": i,
                        "run_id": run_id,
                        "optimizer_G": optimizer_G.state_dict()
                    }, output_path)
                print(f"Saved {output_path.stem}.")
Example #16
0
class Seq2seqKpGen(object):
    """High level model that handles intializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------

    def __init__(self, args, word_dict, state_dict=None):
        """Build the seq2seq network, optionally restoring saved weights.

        Args:
            args: configuration namespace; `args.vocab_size` is overwritten
                here from the dictionary size.
            word_dict: vocabulary mapping used by the network.
            state_dict: optional network weights to load.
        """
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.updates = 0

        self.network = Sequence2Sequence(self.args, self.word_dict)
        if state_dict:
            self.network.load_state_dict(state_dict)

    def activate_fp16(self):
        """Enable mixed-precision via NVIDIA apex.

        Without an optimizer (inference only) the network is simply halved;
        otherwise apex `amp.initialize` wraps both network and optimizer.
        Raises ImportError if apex is not installed.
        """
        if not hasattr(self, 'optimizer'):
            self.network.half()  # for testing only
            return
        try:
            # Bind `amp` globally so update() can reference it later.
            global amp
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # https://github.com/NVIDIA/apex/issues/227
        assert self.optimizer is not None
        self.network, self.optimizer = amp.initialize(self.network,
                                                      self.optimizer,
                                                      opt_level=self.args.fp16_opt_level)

    def init_optimizer(self, optim_state=None, sched_state=None):
        """Create AdamW optimizer and warmup scheduler; optionally restore
        their saved states (moving optimizer tensors onto the target device).

        Args:
            optim_state: optional optimizer state_dict to load.
            sched_state: optional scheduler state_dict to load.
        """
        def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
            # Linear warmup from 0 to the base LR, then constant.
            def lr_lambda(current_step: int):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1.0, num_warmup_steps))
                return 1.0

            return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

        # Standard transformer convention: no weight decay on biases/LayerNorm.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.network.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {"params": [p for n, p in self.network.named_parameters()
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
        self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.args.warmup_steps)

        if optim_state:
            self.optimizer.load_state_dict(optim_state)
            # Restored optimizer tensors land on CPU; move them to the GPU
            # so optimizer.step() does not mix devices.
            if self.args.device.type == 'cuda':
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(self.args.device)
        if sched_state:
            self.scheduler.load_state_dict(sched_state)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        source_map, alignment = None, None
        if self.args.copy_attn:
            # Copy attention needs the source-vocab map and target alignment.
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)
        target_rep = ex['target_rep'].to(self.args.device)
        target_len = ex['target_len'].to(self.args.device)

        # Run forward
        ml_loss, loss_per_token = self.network(source=source_rep,
                                               source_len=source_len,
                                               target=target_rep,
                                               target_len=target_len,
                                               src_map=source_map,
                                               alignment=alignment)

        # Under DataParallel each GPU returns its own loss; average them.
        loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss
        if self.args.fp16:
            global amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            clip_grad_norm_(amp.master_params(self.optimizer), self.args.grad_clipping)
        else:
            loss.backward()
            clip_grad_norm_(self.network.parameters(), self.args.grad_clipping)

        self.updates += 1
        self.optimizer.step()
        self.scheduler.step()  # Update learning rate schedule
        self.optimizer.zero_grad()

        loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 else loss_per_token
        loss_per_token = loss_per_token.item()
        # Cap at 10 so exp() below cannot overflow into huge perplexities.
        loss_per_token = 10 if loss_per_token > 10 else loss_per_token
        perplexity = math.exp(loss_per_token)

        return {
            'ml_loss': loss.item(),
            'perplexity': perplexity
        }

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex, replace_unk=False):
        """Forward a batch of examples only to get predictions.
        Args:
            ex: the batch examples
            replace_unk: replace `unk` tokens while generating predictions
            src_raw: raw source (passage); required to replace `unk` term
        Output:
            predictions: #batch predicted sequences
        """

        def convert_text_to_string(text):
            """ Converts a sequence of tokens (string) in a single string. """
            out_string = text.replace(" ##", "").strip()
            return out_string

        self.network.eval()

        source_map, alignment = None, None
        blank, fill = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)
            blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab'])

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)

        # No target given: the network runs in generation mode.
        decoder_out = self.network(source=source_rep,
                                   source_len=source_len,
                                   target=None,
                                   target_len=None,
                                   src_map=source_map,
                                   alignment=alignment,
                                   max_len=self.args.max_tgt_len,
                                   tgt_dict=self.word_dict,
                                   blank=blank, fill=fill,
                                   source_vocab=ex['src_vocab'])

        # Convert log-probabilities back to per-token probabilities.
        dec_probs = torch.exp(decoder_out['dec_log_probs'])
        predictions, scores = tens2sen_score(decoder_out['predictions'], dec_probs,
                                             self.word_dict, ex['src_vocab'])
        if replace_unk:
            for i in range(len(predictions)):
                enc_dec_attn = decoder_out['attentions'][i]
                if self.args.model_type == 'transformer':
                    # tgt_len x num_heads x src_len
                    assert enc_dec_attn.dim() == 3
                    enc_dec_attn = enc_dec_attn.mean(1)
                predictions[i] = replace_unknown(predictions[i], enc_dec_attn,
                                                 src_raw=ex['source'][i].tokens)

        # Replace scores at separator positions with the separator token
        # itself so predictions and scores stay position-aligned below.
        for bidx in range(ex['batch_size']):
            for i in range(len(predictions[bidx])):
                if predictions[bidx][i] == constants.KP_SEP:
                    scores[bidx][i] = constants.KP_SEP
                elif predictions[bidx][i] == constants.PRESENT_EOS:
                    scores[bidx][i] = constants.PRESENT_EOS
                else:
                    assert isinstance(scores[bidx][i], float)
                    scores[bidx][i] = str(scores[bidx][i])

        predictions = [' '.join(item) for item in predictions]
        scores = [' '.join(item) for item in scores]

        present_kps = []
        absent_kps = []
        present_kp_scores = []
        absent_kp_scores = []
        for bidx in range(ex['batch_size']):
            # Everything before the last PRESENT_EOS is "present" keyphrases;
            # the remainder holds the "absent" ones.
            keyphrases = predictions[bidx].split(constants.PRESENT_EOS)
            kp_scores = scores[bidx].split(constants.PRESENT_EOS)
            pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1])
            pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1])
            akps = keyphrases[-1]
            akp_scores = kp_scores[-1]

            pre_kps = []
            pre_kp_scores = []
            for pkp, pkp_s in zip(pkps.split(constants.KP_SEP),
                                  pkp_scores.split(constants.KP_SEP)):
                pkp = pkp.strip()
                if pkp:
                    pre_kps.append(convert_text_to_string(pkp))
                    # Keyphrase score = product of token probs / length.
                    t_scores = [float(i) for i in pkp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    pre_kp_scores.append(_score)

            present_kps.append(pre_kps)
            present_kp_scores.append(pre_kp_scores)

            abs_kps = []
            abs_kp_scores = []
            for akp, akp_s in zip(akps.split(constants.KP_SEP),
                                  akp_scores.split(constants.KP_SEP)):
                akp = akp.strip()
                if akp:
                    abs_kps.append(convert_text_to_string(akp))
                    t_scores = [float(i) for i in akp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    abs_kp_scores.append(_score)

            absent_kps.append(abs_kps)
            absent_kp_scores.append(abs_kp_scores)

        return {
            'present_kps': present_kps,
            'absent_kps': absent_kps,
            'present_kp_scores': present_kp_scores,
            'absent_kp_scores': absent_kp_scores
        }

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        """Save network weights, vocabulary and args (no optimizer state)."""
        # Unwrap DataParallel before saving so keys have no "module." prefix.
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        state_dict = copy.copy(network.state_dict())
        params = {
            'state_dict': state_dict,
            'word_dict': self.word_dict,
            'args': self.args,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            # Deliberately best-effort: a failed save must not kill training.
            logger.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        """Save a full training checkpoint (weights + optimizer + scheduler
        + epoch/update counters) suitable for resuming via load_checkpoint."""
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        params = {
            'state_dict': network.state_dict(),
            'word_dict': self.word_dict,
            'args': self.args,
            'epoch': epoch,
            'updates': self.updates,
            'optim_dict': self.optimizer.state_dict(),
            'sched_dict': self.scheduler.state_dict(),
        }
        try:
            torch.save(params, filename)
        except BaseException:
            # Best-effort, same rationale as save().
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load(filename, new_args=None):
        """Load a saved model for inference; `new_args` may override the
        stored args. Returns a Seq2seqKpGen instance."""
        logger.info('Loading model %s' % filename)
        # map_location keeps loading CPU-safe regardless of saving device.
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        args = saved_params['args']
        if new_args:
            args = override_model_args(args, new_args)
        return Seq2seqKpGen(args, word_dict, state_dict)

    @staticmethod
    def load_checkpoint(filename):
        """Load a training checkpoint and restore optimizer/scheduler state.

        Returns:
            (model, epoch) so the caller can resume from the right epoch.
        """
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        epoch = saved_params['epoch']
        updates = saved_params['updates']
        optim_dict = saved_params['optim_dict']
        sched_dict = saved_params['sched_dict']
        args = saved_params['args']
        model = Seq2seqKpGen(args, word_dict, state_dict)
        model.updates = updates
        model.init_optimizer(optim_dict, sched_dict)
        return model, epoch

    # --------------------------------------------------------------------------
    # Runtime
    # --------------------------------------------------------------------------

    def to(self, device):
        """Move the underlying network to `device`."""
        self.network = self.network.to(device)

    def parallelize(self):
        """Wrap the network in DataParallel for multi-GPU training."""
        self.network = torch.nn.DataParallel(self.network)
Example #17
0
class Trainer():
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        """Set up model, optimizer, criterion, loggers and data generators.

        Args:
            config: nested dict with 'dataset', 'trainer', 'aug', 'monitor',
                'optimizer' and 'predictor' sections.
            pretrained: kept for signature compatibility; actual pretrained
                weights are taken from config['trainer']['pretrained'].
            augmentor: image augmentation transform, used only when
                config['aug']['image_aug'] is enabled.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        # Tensorboard writer shares the monitor log directory.
        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            # Finetuning: fixed small LR, no LR schedule.
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:
            # From-scratch training: AdamW + one-cycle LR policy.
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Restored optimizer tensors land on CPU; move them to the
            # training device so optimizer.step() does not mix devices.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # Fix: only log the validation generator when it exists. Previously
        # this raised AttributeError when valid_annotation was not configured,
        # because self.valid_gen is only created in that branch above.
        if self.valid_annotation:
            self.logger.info("Number batch samples of valid: %d" %
                             len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Main optimization loop.

        Runs `num_iters` steps over the (re-cycled) training generator,
        logging the averaged train loss every `print_every` iterations and
        running validation + checkpointing every `valid_every` iterations.
        """
        total_loss = 0
        # Fix: initialize before the loop. Previously `lastest_loss` was only
        # assigned inside the print_every branch, so a validation interval
        # firing before the first print interval raised NameError.
        latest_loss = 0.0

        total_loader_time = 0
        total_gpu_time = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()

            # Restart the iterator when the generator is exhausted so
            # training can run for more steps than one pass over the data.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start
            start = time.time()

            # LOSS: one optimization step; returns a scalar loss.
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))

            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:

                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                latest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)

                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                # Keep the checkpoint with the best full-sequence accuracy.
                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq

                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))

                # Always refresh the "last" checkpoint for resuming.
                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                log_loss = {'train loss': latest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """Compute the mean loss over the validation generator.

        Runs without gradients; restores train mode before returning.
        """
        self.model.eval()

        losses = []

        with torch.no_grad():
            for batch in self.valid_gen:
                batch = self.batch_to_device(batch)
                img = batch['img']
                tgt_input = batch['tgt_input']
                tgt_output = batch['tgt_output']
                tgt_padding_mask = batch['tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)

                if self.model.seq_modeling == 'crnn':
                    # CTC needs explicit prediction/label lengths.
                    label_lengths = batch['labels_len']
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    batch_loss = self.criterion(outputs, tgt_output, preds_size,
                                                label_lengths)
                else:
                    # Flatten (batch, time, vocab) -> (batch*time, vocab)
                    # to match the label-smoothing criterion's interface.
                    flat_outputs = outputs.flatten(0, 1)
                    flat_targets = tgt_output.flatten()
                    batch_loss = self.criterion(flat_outputs, flat_targets)

                losses.append(batch_loss.item())

                # Free the graph-holding tensors eagerly.
                del outputs
                del batch_loss

        mean_loss = np.mean(losses)
        self.model.train()

        return mean_loss

    def predict(self, sample=None):
        """Run inference over the validation generator.

        Args:
            sample: if set, stop once more than `sample` predictions have
                been collected (whole batches, so the count may overshoot).

        Returns:
            (pred_sents, actual_sents, img_files, probs_sents, imgs_sents):
            decoded predictions, ground-truth strings, source filenames,
            per-sentence probabilities, and the image tensors.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    # NOTE(review): beam search yields no probabilities, and
                    # probs_sents.extend(None) below raises TypeError —
                    # looks like this path was never exercised; confirm.
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                # CRNN output needs CTC-style decoding (crnn=True).
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

            # Visualize in tensorboard
            # Only the first batch is plotted; failures are swallowed so a
            # plotting problem never aborts prediction.
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        # CHW tensor -> HWC numpy array for matplotlib.
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img], probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img]
                                   == preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })

                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    print(error)
                    # NOTE(review): `continue` also skips the sample-count
                    # check below for this batch — likely unintended.
                    continue

            if sample != None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Predict on the validation set and report accuracy metrics.

        Returns:
            (acc_full_seq, acc_per_char, wer): full-sequence accuracy,
            per-character accuracy, and word error rate.
        """
        started = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        elapsed = time.time() - started

        sensitive_case = self.config['predictor']['sensitive_case']

        def _accuracy(mode):
            # Same inputs and case sensitivity for every metric; only the
            # comparison mode varies.
            return compute_accuracy(actual_sents, pred_sents,
                                    sensitive_case, mode=mode)

        acc_full_seq = _accuracy('full_sequence')
        acc_per_char = _accuracy('per_char')
        wer = _accuracy('wer')

        if measure_time:
            # Average prediction time per sentence.
            print("Time: {:.4f}".format(elapsed / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Show a grid of images with predicted vs. actual labels.

        Args:
            sample: maximum number of images to display.
            errorcase: if True, show only mispredicted samples.
            fontname, fontsize: title font settings.
            save_fig: if True, also write the figure to vis_prediction.png.
        """
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            # Keep only indices where prediction and ground truth disagree.
            wrongs = [i for i in range(len(img_files))
                      if pred_sents[i] != actual_sents[i]]

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}
        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))

        for vis_idx in range(len(img_files)):
            row, col = divmod(vis_idx, ncols)

            predicted = pred_sents[vis_idx]
            expected = actual_sents[vis_idx]
            confidence = probs[vis_idx]
            # CHW tensor -> HWC numpy array for matplotlib.
            image = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()

            axis = ax[row, col]
            axis.imshow(image)
            axis.set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    predicted, expected, confidence),
                fontname=fontname,
                color='r' if predicted != expected else 'g')
            axis.get_xaxis().set_ticks([])
            axis.get_yaxis().set_ticks([])

        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Predict on up to ``sample`` validation examples and dump the
        (prediction, label, filename) rows to ``csv_file``."""
        predictions, labels, filenames, _, _ = self.predict(sample)
        save_predictions(csv_file, predictions, labels, filenames)

    def vis_data(self, sample=20):
        """Save a grid of up to ``sample`` training images with their
        decoded labels to 'vis_dataset.png'.

        Args:
            sample: number of images to plot.
        """
        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        # squeeze=False keeps ax 2-D even when nrows == 1 (sample <= 5),
        # otherwise ax[row, col] below would raise on a 1-D axes array.
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12), squeeze=False)

        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            # Iterate the actual batch length: the last batch may hold
            # fewer than self.batch_size items.
            for vis_idx in range(len(batch['img'])):
                row = num_plots // ncols
                col = num_plots % ncols

                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())

                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')

                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])

                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return
        # Dataset exhausted before reaching ``sample``: still write what
        # was drawn (previously the figure was silently dropped).
        plt.subplots_adjust()
        fig.savefig('vis_dataset.png')

    def load_checkpoint(self, filename):
        """Restore model, optimizer, scheduler and bookkeeping state from
        a checkpoint produced by :meth:`save_checkpoint`.

        Args:
            filename: path to the checkpoint file.
        """
        ckpt = torch.load(filename)

        self.model.load_state_dict(ckpt['state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(ckpt['scheduler'])

        self.iter = ckpt['iter']
        self.train_losses = ckpt['train_losses']
        self.best_acc = ckpt['best_acc']

    def save_checkpoint(self, filename):
        """Write the full training state (weights, optimizer, scheduler,
        loss history, iteration counter, best accuracy) to ``filename``,
        creating parent directories as needed."""
        scheduler_state = (None if self.scheduler is None
                           else self.scheduler.state_dict())
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': scheduler_state,
            'best_acc': self.best_acc,
        }

        directory = os.path.split(filename)[0]
        os.makedirs(directory, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from ``filename``.

        Accepts either a full checkpoint (detected via
        :meth:`is_checkpoint`) or a bare state dict; in the latter case
        missing or shape-mismatched tensors are reported and skipped so
        partially compatible weights still load.
        """
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
            return

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
                continue
            if state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.
                      format(name, param.shape, state_dict[name].shape))
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state dict to ``filename``, creating
        parent directories as needed."""
        directory = os.path.split(filename)[0]
        os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """Return True when ``checkpoint`` looks like a full training
        checkpoint (it carries a 'state_dict' entry) rather than a bare
        state dict.

        Args:
            checkpoint: object loaded via torch.load; may be any type.
        """
        try:
            checkpoint['state_dict']
        # Narrowed from a bare ``except:`` which also swallowed
        # KeyboardInterrupt/SystemExit. These three cover dict-miss and
        # non-subscriptable / wrongly-indexed objects.
        except (TypeError, KeyError, IndexError):
            return False
        return True

    def batch_to_device(self, batch):
        """Return a copy of a collated batch with its tensor fields moved
        to ``self.device`` (non-blocking); 'filenames' and 'labels_len'
        are passed through unchanged on the host."""
        tensor_keys = ('img', 'tgt_input', 'tgt_output', 'tgt_padding_mask')
        moved = {
            key: batch[key].to(self.device, non_blocking=True)
            for key in tensor_keys
        }
        moved['filenames'] = batch['filenames']
        moved['labels_len'] = batch['labels_len']
        return moved

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more LMDB-backed OCR datasets.

        Args:
            lmdb_paths: list of LMDB directories to read.
            data_root: root folder image paths are relative to.
            annotation: annotation file path.
            masked_language_model: forwarded to the Collator.
            transform: optional image augmentation.
            is_train: when True, shuffle (only effective when no custom
                sampler is in use).

        Returns:
            torch.utils.data.DataLoader over the (concatenated) dataset.
        """
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        # Bug fix: the original tested len(self.train_lmdb), which is
        # wrong when this method builds a validation loader. Decide on
        # the datasets actually constructed here.
        if len(datasets) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)

        if self.is_padding:
            sampler = None
        else:
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)

        # DataLoader raises ValueError when shuffle=True is combined
        # with a custom sampler, so only shuffle in the sampler-less
        # (padding) configuration.
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train and sampler is None,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])

        return gen

    def step(self, batch):
        """Run one optimization step on ``batch`` and return the loss.

        Supports both the CTC path ('crnn' sequence modeling) and the
        cross-entropy style criterion, clips gradients at norm 1, and
        advances the LR scheduler unless fine-tuning.

        Args:
            batch: collated batch as produced by the data loader.

        Returns:
            float: the scalar training loss for this batch.
        """
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)

        if self.model.seq_modeling == 'crnn':
            # CTC criterion: all samples share the same output length.
            # drop_last is enabled in the loader for crnn, so
            # self.batch_size matches the actual batch here.
            length = batch['labels_len']
            # torch.autograd.Variable is a deprecated no-op wrapper
            # since torch 0.4; a plain tensor is equivalent.
            preds_size = torch.IntTensor([outputs.size(0)] * self.batch_size)
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(-1, outputs.size(2))  # B*S x N_class
            tgt_output = tgt_output.view(-1)  # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()

        if not self.is_finetuning:
            self.scheduler.step()

        return loss.item()

    def count_parameters(self, model):
        """Return the number of trainable (requires_grad) parameters."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    def gen_pseudo_labels(self, outfile=None):
        """Run inference over the validation loader and write pseudo
        labels to ``outfile`` as 'filename||||text||||prob' lines.

        Args:
            outfile: output path; required.

        Raises:
            ValueError: if ``outfile`` is not given (previously this
                crashed later inside ``open(None, ...)``).
        """
        if outfile is None:
            raise ValueError('outfile is required')

        pred_sents = []
        img_files = []
        probs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            # Bug fix: beam search yields no per-sentence confidence and
            # left prob=None, so probs_sents.extend(prob) raised
            # TypeError. Write 0.0 placeholders instead (meaning
            # "confidence unknown").
            if prob is None:
                prob = [0.0] * len(pred_sent)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
# ---- Example #18 ----
class Trainer():
    """End-to-end OCR trainer.

    Builds the model and vocabulary from ``config``, sets up data
    loaders, AdamW + OneCycleLR optimization, and a label-smoothing
    criterion, then drives the train / validate / checkpoint loop.
    """

    def __init__(self, config, pretrained=True):
        """Initialize model, data pipelines, optimizer and criterion.

        Args:
            config: nested configuration dict (dataset/trainer/predictor
                sections).
            pretrained: when True, download and load pretrained weights.
        """
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        # Bug fix: previously self.logger was assigned only when a log
        # path was configured, so train() later crashed with
        # AttributeError. Keep it None and guard every use.
        self.logger = Logger(logger) if logger else None

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        # OneCycleLR drives the learning rate, so no explicit lr is
        # passed to AdamW here.
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []

    def train(self):
        """Main optimization loop.

        Runs ``num_iters`` steps, cycling the train loader as needed,
        logging every ``print_every`` steps, validating every
        ``valid_every`` steps, and exporting weights whenever full-
        sequence accuracy improves.
        """
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch boundary: restart the loader.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                if self.logger:
                    self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                if self.logger:
                    self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        """Compute the mean criterion loss over the validation loader.

        Returns:
            float: mean validation loss.
        """
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                # Consistency fix: pass the padding mask by keyword as
                # step() does; the original passed it positionally.
                outputs = self.model(img,
                                     tgt_input,
                                     tgt_key_padding_mask=tgt_padding_mask)

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        """Decode up to ``sample`` validation examples.

        Returns:
            Tuple ``(pred_sents, actual_sents, img_files)``.
        """
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):
        """Score predictions on the validation set.

        Returns:
            Tuple ``(acc_full_seq, acc_per_char)``.
        """
        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')

        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Show up to ``sample`` validation images with predicted vs.
        actual transcriptions; with ``errorcase`` only mispredictions."""
        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = [i for i in range(len(img_files))
                      if pred_sents[i] != actual_sents[i]]

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')

        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Show up to ``sample`` training images with decoded labels."""
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')

                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        """Restore model, optimizer and bookkeeping state from
        ``filename``.

        NOTE(review): the scheduler state is not restored here — confirm
        whether resumed runs should rebuild OneCycleLR from scratch.
        """
        checkpoint = torch.load(filename)

        # Dead-code fix: the original constructed an unused
        # ScheduledOptim here and discarded it.
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Write model/optimizer state, iteration counter and loss
        history to ``filename``, creating parent directories."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load a bare state dict, skipping missing or shape-mismatched
        tensors so partially compatible weights still load."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state dict to ``filename``, creating
        parent directories."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move the tensor fields of a collated batch onto
        ``self.device``; filenames stay on the host."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        """Build a cluster-sampled DataLoader over one LMDB dataset."""
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])

        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        """Legacy generator kept for backwards compatibility; note that
        ``lmdb_path`` is accepted but unused by DataGen."""
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        """Run one optimization step on ``batch`` and return the scalar
        training loss."""
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        outputs = outputs.view(-1, outputs.size(2))  # B*S x N_class
        tgt_output = tgt_output.view(-1)  # B*S

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        return loss.item()
# ---- Example #19 ----
class Training:
    def __init__(self, model, device, config, name, fold_num, imsize):
        self.config = config
        self.epoch = 0
        self.base_dir = './models/'
        os.makedirs('./models', exist_ok=True)
        self.model = model
        self.best_loss = 10**5
        self.device = device
        self.name = name
        self.fold_num = fold_num
        self.imsize = imsize
        # optimize
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.00
        }]
        self.optimizer = AdamW(self.model.parameters(), lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer,
                                               **config.scheduler_params)
        # Earlystopping
        self.patience = config.patience
        # GradScaler
        self.scaler = GradScaler()

    def train_one_epoch(self, train_loader):
        self.model.train()
        showloss = Showloss()

        for step, (images, targets) in tqdm(enumerate(train_loader),
                                            total=len(train_loader)):
            self.optimizer.zero_grad()

            with autocast():
                images = torch.stack(
                    images)  # 이미지들을 합쳐 Batch 생성 (default: dim=0) [B,C,H,W]
                images = images.to(self.device).float()
                batch_size = images.shape[0]
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float() for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                # update 후로 forward는 image와 target_dict를 인자로 받음
                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # pred
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return showloss

    def val_one_epoch(self, val_loader):
        self.model.eval()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(val_loader),
                                            total=len(val_loader)):
            with torch.no_grad():
                images = torch.stack(images)
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float() for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # loss, _, _ = self.model(images, boxes, labels)
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

        return showloss

    def save(self, path):  # 모델 및 파라미터 저장
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': self.best_loss,  # val
                'epoch': self.epoch,
            },
            path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']  # val
        self.epoch = checkpoint['epoch'] + 1

    def fit(self, train_loader, val_loader):
        early_stopping = EarlyStopping(self.patience)
        for epoch in range(self.config.n_epochs):
            print('{} / {} Epoch'.format(epoch, self.config.n_epochs))
            train_loss = self.train_one_epoch(train_loader)
            print('[Train] loss: {}'.format(train_loss.avg))
            self.save(self.base_dir +
                      '{}_{}_last.pt'.format(self.name, self.fold_num))

            val_loss = self.val_one_epoch(val_loader)
            print('[Valid] loss: {}'.format(val_loss.avg))

            if val_loss.avg < self.best_loss:
                self.best_loss = val_loss.avg
                self.save(self.base_dir +
                          '{}_{}_best.pt'.format(self.name, self.fold_num))

            # Early stopping
            early_stopping(val_loss.avg, self.best_loss)
            if early_stopping.early_stop:
                break

            if self.config.val_scheduler:
                self.scheduler.step(metrics=val_loss.avg)

            self.epoch += 1
class TrainLoop:
    """Training loop for a diffusion model.

    Handles microbatched forward/backward passes, optional fp16 training
    with dynamic loss scaling, EMA parameter tracking at one or more rates,
    linear LR annealing, distributed (DDP) execution, and checkpoint
    save/resume.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        # NOTE(review): run_loop() calls next(self.data), so `data` is assumed
        # to be a (practically infinite) iterator of (batch, cond) pairs.
        self.data = data
        self.batch_size = batch_size
        # A non-positive microbatch disables microbatching entirely.
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        # Accept a single float or a comma-separated string of EMA rates.
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        # Effective batch size summed across all distributed ranks.
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params = list(self.model.parameters())
        # In fp16 mode _setup_fp16() replaces this with fp32 master copies;
        # otherwise master and model parameters are the same objects.
        self.master_params = self.model_params
        # log2 of the fp16 loss scale; adjusted dynamically in optimize_fp16().
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            # Fresh run: each EMA copy starts from the current parameters.
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            # Without CUDA, DDP gradient all-reduce is unavailable; fall back
            # to the bare model and warn if multiple ranks are present.
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from a resume checkpoint (on rank 0 only) and
        broadcast them to every rank."""
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            # The step count is encoded in the checkpoint filename.
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )

        # Ensure all ranks hold identical parameters, loaded or not.
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for *rate*, loaded from the matching EMA
        checkpoint when one exists, otherwise copied from the master params."""
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)

        # Broadcast rank 0's EMA params so all ranks agree.
        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore optimizer state from `opt<step>.pt` next to the main
        checkpoint, if such a file exists."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        """Create fp32 master copies of the parameters and convert the model
        itself to fp16."""
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        """Main training loop: step until `lr_anneal_steps` is reached (or
        forever when it is 0), logging and checkpointing periodically."""
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """One optimization step: forward/backward over all microbatches,
        then the fp16- or fp32-specific optimizer update, then logging."""
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Accumulate gradients over microbatches of (batch, cond)."""
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            # Sample diffusion timesteps (and importance weights) per example.
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            # Only sync DDP gradients on the final microbatch; intermediate
            # backwards run under no_sync() to avoid redundant all-reduces.
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            # Feed per-timestep losses back to a loss-aware sampler so it can
            # adapt its sampling distribution.
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                # Scale the loss up to keep fp16 gradients from underflowing;
                # the scale is divided back out in optimize_fp16().
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        """fp16 optimizer step with dynamic loss scaling: skip the step and
        shrink the scale when gradients overflow to NaN/inf, otherwise
        unscale, step, update EMAs, and slowly grow the scale."""
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        # Undo the loss scaling applied in forward_backward().
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        # Copy updated fp32 master weights back into the fp16 model.
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        """Standard fp32 optimizer step followed by EMA updates."""
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        """Log the global L2 norm of all master-parameter gradients."""
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        """Linearly decay the learning rate to 0 over `lr_anneal_steps`
        (no-op when lr_anneal_steps is 0)."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Record step/sample counters (and the fp16 loss scale) for logging."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        """Checkpoint the model, every EMA copy, and the optimizer state.

        Only rank 0 writes files; all ranks synchronize at the end.
        """
        def save_checkpoint(rate, params):
            # rate == 0 identifies the raw (non-EMA) model parameters.
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        # Keep all ranks in lockstep so no rank races ahead of the save.
        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        """Convert a flat list of master parameters into a state dict keyed
        like the model's own state_dict."""
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Inverse of _master_params_to_state_dict: rebuild the master
        parameter list (flattened fp32 copies in fp16 mode)."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
# Example #21
# 0
def train_runner(model: nn.Module,
                 model_name: str,
                 results_dir: str,
                 experiment: str = '',
                 debug: bool = False,
                 img_size: int = IMG_SIZE,
                 learning_rate: float = 1e-2,
                 fold: int = 0,
                 checkpoint: str = '',
                 epochs: int = 15,
                 batch_size: int = 8,
                 num_workers: int = 4,
                 start_epoch: int = 0,
                 save_oof: bool = False,
                 save_train_oof: bool = False,
                 gpu_number: int = 1):
    """
    Model training runner

    Args:
        model        : PyTorch model
        model_name   : string name for model for checkpoints saving
        results_dir  : directory to save results
        experiment   : string name for naming experiments
        debug        : if True, runs the debugging on few images
        img_size     : size of images for training
        learning_rate: initial learning rate (default = 1e-2)
        fold         : training fold (default = 0)
        checkpoint   : path to a checkpoint to resume from ('' = from scratch)
        epochs       : number of epochs to train
        batch_size   : number of images in batch
        num_workers  : number of workers available
        start_epoch  : number of epoch to continue training
        save_oof     : saves oof validation predictions. Default = False
        save_train_oof: saves oof train predictions. Default = False
        gpu_number   : CUDA device index to train on. Default = 1
    """
    device = torch.device(
        f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu')
    print(device)

    # load model weights to continue training
    if checkpoint != '':
        model, ckpt = load_model(model, checkpoint)
        # (fixed: previously stored ckpt metrics in unused, typo'd locals)
        start_epoch = ckpt['epoch'] + 1
        print('Loaded model from {}, epoch {}'.format(checkpoint, start_epoch))

    model.to(device)

    # creates directories for checkpoints, tensorboard and predicitons
    checkpoints_dir = f'{results_dir}rgb/checkpoints/{model_name}{experiment}'
    predictions_dir = f'{results_dir}rgb/oof/{model_name}{experiment}'
    validations_dir = f'{results_dir}rgb/oof_val/{model_name}{experiment}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(predictions_dir, exist_ok=True)
    os.makedirs(validations_dir, exist_ok=True)
    print('\n', model_name, '\n')

    # datasets for train and validation: rows of the current fold validate,
    # all other folds train
    df = pd.read_csv(f'{TRAIN_DIR}folds.csv')
    df_train = df[df.fold != fold]
    df_val = df[df.fold == fold]
    print(
        f'Train images: {len(df_train.ImageId.values)}, valid images {len(df_val.ImageId.values)}'
    )

    train_dataset = RGBDataset(
        images_dir=TRAIN_RGB,
        masks_dir=TRAIN_MASKS,
        labels_df=df_train,
        img_size=img_size,
        transforms=TRANSFORMS["medium"],
        normalise=True,
        debug=debug,
    )
    valid_dataset = RGBDataset(
        images_dir=TRAIN_RGB,
        masks_dir=TRAIN_MASKS,
        labels_df=df_val,
        img_size=img_size,
        transforms=TRANSFORMS["d4"],
        normalise=True,
        debug=debug,
    )
    # dataloaders for train and validation
    dataloader_train = DataLoader(train_dataset,
                                  num_workers=num_workers,
                                  batch_size=batch_size,
                                  shuffle=True)

    dataloader_valid = DataLoader(valid_dataset,
                                  num_workers=num_workers,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  drop_last=True)
    print('{} training images, {} validation images'.format(
        len(train_dataset), len(valid_dataset)))

    # optimizers and schedulers; scheduler maximizes validation mIoU
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    #optimizer = RAdam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='max',
                                               patience=2,
                                               verbose=True,
                                               factor=0.2,
                                               min_lr=1e-6)
    # load optimizer state continue training
    #if checkpoint != '':
    #    optimizer = load_optim(optimizer, checkpoint, device)

    # criteria: plain BCE is reported for reference, the combined
    # BCE + Jaccard loss is what training optimizes
    criterion1 = nn.BCEWithLogitsLoss()
    criterion = BCEJaccardLoss(bce_weight=2,
                               jaccard_weight=0.5,
                               log_loss=False,
                               log_sigmoid=True)
    #criterion = JaccardLoss(log_sigmoid=True, log_loss=False)

    # basic logging
    report_batch = 200
    report_epoch = 20
    log_file = os.path.join(checkpoints_dir, f'fold_{fold}.log')
    logging.basicConfig(filename=log_file, filemode="w", level=logging.DEBUG)
    # (fixed: the checkpoint field previously logged start_epoch)
    logging.info(
        f'Parameters:\n model_name: {model_name}\n, results_dir: {results_dir}\n, experiment: {experiment}\n, img_size: {img_size}\n, \
                 learning_rate: {learning_rate}\n, fold: {fold}\n, epochs: {epochs}\n, batch_size: {batch_size}\n, num_workers: {num_workers}\n, \
                 start_epoch: {start_epoch}\n, save_oof: {save_oof}\n, optimizer: {optimizer}\n, scheduler: {scheduler} \n, checkpoint: {checkpoint} \n'
    )

    def _save_checkpoint(filepath):
        # Snapshot of model/optimizer plus the current epoch's metrics.
        # Reads epoch / epoch_losses / valid_metrics from the enclosing
        # scope at call time.
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'loss': np.mean(epoch_losses),
                'valid_loss': valid_metrics['val_loss'],
                'valid_miou': valid_metrics['miou'],
            }, filepath)

    train_losses = []
    best_val_loss = 1e+5
    best_val_metric = 0
    # training cycle
    # (fixed off-by-one: the original range ended at start_epoch + epochs + 1
    # and therefore trained one epoch more than requested)
    print("Start training")
    for epoch in range(start_epoch, start_epoch + epochs):
        print("Epoch", epoch)
        # validation below may leave the model in eval mode; make sure we
        # are back in training mode each epoch
        model.train()
        epoch_losses = []
        progress_bar = tqdm(dataloader_train, total=len(dataloader_train))
        progress_bar.set_description('Epoch {}'.format(epoch))
        # gradients are enabled by default during training, so the original
        # torch.set_grad_enabled(True) wrapper was a no-op and was removed
        for batch_num, (img, target, _) in enumerate(progress_bar):
            img = img.to(device)
            target = target.float().to(device)
            prediction = model(img)
            loss = criterion(prediction, target)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
            optimizer.step()
            epoch_losses.append(loss.detach().cpu().numpy())

            if batch_num and batch_num % report_batch == 0:
                neptune.log_metric('Train loss', np.mean(epoch_losses))

        # log loss history
        print("Epoch {}, Train Loss: {}".format(epoch, np.mean(epoch_losses)))
        train_losses.append(np.mean(epoch_losses))
        neptune.log_metric('Train loss', np.mean(epoch_losses))
        logging.info(
            f'epoch: {epoch}; step: {batch_num}; loss: {np.mean(epoch_losses)} \n'
        )

        # validate model
        val_loss = validate_loss(model, dataloader_valid, criterion1, epoch,
                                 validations_dir, device)

        valid_metrics = validate(model, dataloader_valid, criterion, epoch,
                                 validations_dir, save_oof, device)
        # logging metrics
        neptune.log_metric('bce_loss_valid', val_loss)
        neptune.log_metric('loss_valid', valid_metrics['val_loss'])
        neptune.log_metric('miou_valid', valid_metrics['miou'])

        # get current learning rate
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        print(f'learning_rate: {lr}')
        logging.info(f'learning_rate: {lr}\n')
        neptune.log_metric('lr', lr)
        scheduler.step(valid_metrics['miou'])

        # save the best metric
        if valid_metrics['miou'] > best_val_metric:
            best_val_metric = valid_metrics['miou']
            print(
                f"Saving model with the best val metric {valid_metrics['miou']}, epoch {epoch}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             f"{model_name}_best_val_miou.pth"))
        # save the best loss
        if valid_metrics['val_loss'] < best_val_loss:
            best_val_loss = valid_metrics['val_loss']
            print(
                f"Saving model with the best val loss {valid_metrics['val_loss']}, epoch {epoch}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             "{}_best_val_loss.pth".format(model_name)))
        # save model, optimizer and losses after every n epoch
        elif epoch % report_epoch == 0:
            print(
                f"Saving model at epoch {epoch}, val loss {valid_metrics['val_loss']}"
            )
            _save_checkpoint(
                os.path.join(checkpoints_dir,
                             "{}_epoch_{}.pth".format(model_name, epoch)))
# Example #22
# 0
def run_training(opt):
    """End-to-end training driver for the document-image classifier.

    Builds train/validation dataloaders from ``opt.train_csv``, constructs
    the model, optimizer, scheduler and losses, optionally resumes from
    ``<work_dir>/last.pt``, then trains with per-epoch validation,
    best/last checkpointing, and early stopping after ``opt.patience``
    epochs without improvement.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Checkpoint paths
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir

    n_classes = 6  # fixed coding :V

    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)

    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # Per-domain validation loaders are needed by valid_one_epoch() in the
    # `not opt.ovr_val` branch below. BUG FIX: this block was commented out,
    # so that branch crashed with a NameError on handwritten_val_loader /
    # printed_val_loader.
    if not opt.ovr_val:
        handwritten_data = pd.read_csv(opt.handwritten_csv)
        printed_data = pd.read_csv(opt.printed_csv)
        handwritten_data['class'] = handwritten_data.apply(
            lambda row: categ[row["class"]], axis=1)
        printed_data['class'] = printed_data.apply(
            lambda row: categ[row["class"]], axis=1)
        _, handwritten_val_loader = prepare_dataloader(
            handwritten_data, opt.fold, train_batch, valid_batch,
            opt.img_size, opt.num_workers, data_root=images_path)
        _, printed_val_loader = prepare_dataloader(
            printed_data, opt.fold, train_batch, valid_batch,
            opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------

    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)

    # Optionally warm-start from pretrained weights (distinct from resuming)
    if weights is not None:
        cp = torch.load(weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()

    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    os.makedirs(work_dir, exist_ok=True)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping

    if opt.resume:
        checkpoint = torch.load(last)

        # Checkpoints are written at the end of an epoch, so resume from the
        # NEXT one (the original restarted — and repeated — the saved epoch).
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                # `best` is already a full path; os.path.join(best) was a no-op
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, best)

                print('best model found for epoch {}'.format(epoch + 1))

        # Always keep the latest state for resuming
        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, last)

        # Early stopping: no improvement for more than `patience` epochs
        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
# Example #23
# 0
    # NOTE(review): this divides the accumulated *train* loss by the number
    # of *validation* batches — almost certainly should be
    # len(train_dataloader); confirm before relying on 'Training Loss' below.
    avg_train_loss = total_train_loss / len(validation_dataloader)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))
    # Accumulate per-epoch metrics; the full list is also serialized into
    # each checkpoint below.
    training_stats.append({
        'Avg Accuracy': avg_val_accuracy,
        'Bleu Score': avg_bleuscore,
        'Training Loss': avg_train_loss,
        'Valid. Loss': avg_val_loss,
        'Validation Time': validation_time
    })

    # Checkpoint the full training state once per epoch.
    # NOTE(review): 'epoch' is stored as epoch_i + 4 while the filename uses
    # epoch_i + 1 — presumably an offset from an earlier resumed run; verify.
    torch.save(
        {
            'epoch': epoch_i + 4,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'total_train_loss': total_train_loss,
            'step': len(train_dataloader),
            'training_stats': training_stats
        }, "/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i + 1) +
        ".pickle")
    print(training_stats)

print("")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time() -
                                                             t0)))

print("done completely")