Example #1
    def train(self, epoch):
        msg = '\nTrain at Epoch: {:d}'.format(epoch)
        print(msg)

        self.netFeat.train()
        self.netClassifier.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        for batchIdx, (inputs, targets) in enumerate(self.trainLoader):

            inputs = to_device(inputs, self.device)
            targets = to_device(targets, self.device)

            self.optimizer.zero_grad()
            outputs = self.netFeat(inputs)
            outputs = self.netClassifier(outputs)
            loss = self.criterion(outputs, targets)

            loss.backward()
            self.optimizer.step()

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size()[0])
            top1.update(acc1[0].item(), inputs.size()[0])
            top5.update(acc5[0].item(), inputs.size()[0])

            msg = 'Loss: {:.3f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(
                losses.avg, top1.avg, top5.avg)
            progress_bar(batchIdx, len(self.trainLoader), msg)

        return losses.avg, top1.avg, top5.avg
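
Every snippet on this page calls to_device, and most use AverageMeter; neither is shown in the examples. Below is a minimal sketch inferred from the call sites above. The recursive handling of lists/dicts and the optional dtype argument are assumptions, not any single project's exact implementation.

import torch

def to_device(data, device, dtype=None):
    # Recursively move tensors (possibly nested in lists, tuples or dicts)
    # to the target device, optionally casting to dtype.
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(x, device, dtype) for x in data)
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype) for k, v in data.items()}
    if torch.is_tensor(data):
        if dtype is not None:
            return data.to(device=device, dtype=dtype)
        return data.to(device)
    return data

class AverageMeter:
    # Running average tracker, as in the PyTorch ImageNet example.
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
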
    def eval_epoch(self):
        """
        Evaluate the network for one epoch and return the average loss.

        Returns:
            loss (float, list(float)): list of mean losses

        """
        self.model.eval()

        if isinstance(self.valid_loader, (tuple, list)):
            iterator = zip(*self.valid_loader)
        else:
            iterator = self.valid_loader

        labels = []
        posteriors = []
        losses = []
        with torch.no_grad():
            for i_batch, batch in enumerate(iterator, 1):

                if isinstance(batch.text[0], list):
                    X = []
                    for item in batch.text[0]:
                        item_tensor = torch.from_numpy(numpy.array(item))
                        X.append(to_device(item_tensor,
                                           device=self.device,
                                           dtype=item_tensor.dtype))

                else:
                    X = to_device(batch.text[0], device=self.device, dtype=batch.text[0].dtype)

                y = to_device(batch.label, device=self.device, dtype=torch.long)

                lengths = to_device(batch.text[1], device=self.device,
                                    dtype=torch.long)

                batch_losses, label, cls_logits = self.process_batch(X, lengths, y)
                labels.append(label)
                posteriors.append(cls_logits)

                # aggregate the losses into a single loss value
                loss, _losses = self.return_tensor_and_list(batch_losses)
                losses.append(_losses)

        posteriors = torch.cat(posteriors, dim=0)
        predicted = numpy.argmax(posteriors.cpu().numpy(), axis=1)
        labels_array = torch.cat(labels, dim=0).cpu().numpy()

        return numpy.array(losses).mean(axis=0), labels_array, predicted
    def evaluate_one_epoch(self, model, dataloader, prefix, debugging=False, save_csv_path=None, show_progress=False):
        """Evaluate the model for one epoch."""
        model.eval()
        tot_inp, tot_outp = [], []

        with torch.no_grad():
            total = 10 if debugging else len(dataloader)
            with tqdm(dataloader, total=total) as t:
                t.set_description(prefix)

                for i, data in enumerate(t):
                    # Input
                    data = to_device(data, self.device)
                    tot_inp.append(data)

                    # Forward
                    output = model(**data)
                    tot_outp.append(output)

                    # Break when reaching 10 iterations when debugging
                    if debugging and i == 9:
                        break

        metrics = compute_metrics_from_inputs_and_outputs(
            inputs=tot_inp, outputs=tot_outp, tokenizer=self.tokenizer, save_csv_path=save_csv_path,
            show_progress=show_progress)

        if metrics is not None:
            self._record_metrics(metrics)

            to_log = json.dumps(metrics, indent=2)
            logger.info(f"{prefix}:\n{to_log}")

        model.train()
        return
Example #4
    def validate(self, valLoader, lr=None, mode='val'):
        if mode == 'test':
            nEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nEpisode))
        elif mode == 'val':
            nEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nEpisode))
            valLoader = iter(valLoader)
        else:
            raise ValueError('Unknown mode: {}'.format(mode))

        episodeAccLog = []
        top1 = AverageMeter()

        self.netFeat.eval()
        # self.netSIB.eval() is deliberately skipped: keeping BN in train mode
        # helps estimate better gradients.

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        #for batchIdx, data in enumerate(valLoader):
        for batchIdx in range(nEpisode):
            data = valLoader.getEpisode() if mode == 'test' else next(
                valLoader)
            data = to_device(data, self.device)

            SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                    data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
                    data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

            with torch.no_grad():
                SupportFeat, QueryFeat = self.netFeat(
                    SupportTensor), self.netFeat(QueryTensor)
                SupportFeat, QueryFeat, SupportLabel = \
                        SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)

            clsScore = self.netSIB(lr, SupportFeat, SupportLabel, QueryFeat)
            clsScore = clsScore.view(QueryFeat.size()[0] * QueryFeat.size()[1],
                                     -1)
            QueryLabel = QueryLabel.view(-1)
            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])

            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, nEpisode, msg)
            episodeAccLog.append(acc1[0].item())

        mean, ci95 = getCi(episodeAccLog)
        self.logger.info(
            'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.
            format(mean, ci95))
        return mean, ci95
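
getCi is not defined in this snippet; from its use here it returns the mean episode accuracy together with a 95% confidence interval. A plausible sketch (1.96 standard errors over episodes, an assumption rather than the authors' exact code):

import numpy as np

def getCi(acc_log):
    # Mean accuracy over episodes and its 95% confidence half-width.
    acc = np.asarray(acc_log)
    mean = acc.mean()
    ci95 = 1.96 * acc.std(ddof=1) / np.sqrt(len(acc))
    return mean, ci95
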
Example #5
    def test_single_crop(self, test_loader):
        test_loader_iterator = iter(test_loader)
        num_test_iters = len(test_loader)
        tt = tqdm(range(num_test_iters), total=num_test_iters, desc="Testing")

        aux_correct = 0
        class_correct = 0
        total = 0
        features = []

        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                data = next(test_loader_iterator)
                if isinstance(data, list):
                    data = data[0]
                # Get the inputs
                data = to_device(data, self.device)
                imgs = data['images']
                cls_lbls = data['class_labels']
                aux_lbls = data['aux_labels']

                aux_logits, class_logits = self.model(imgs)

                _, cls_pred = class_logits.max(dim=1)
                _, aux_pred = aux_logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls.data)
                aux_correct += torch.sum(aux_pred == aux_lbls.data)
                total += imgs.size(0)

                if self.args.testing.get('save_features', False):
                    feats = class_logits.cpu().data.numpy()
                    lbls_np = cls_lbls.cpu().data.numpy()
                    lbls_np = lbls_np[:, np.newaxis]
                    features.append(np.hstack((feats, lbls_np)))

            tt.close()

        aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('{} aux_acc: {:.2f} %, class_acc: {:.2f} %'.format(
            self.args.exp_name, aux_acc, class_acc))

        if self.args.testing.get('save_features', False):
            feature_path = os.path.join(
                self.args.cache_dir,
                self.args.datasets.test.name + '_features.npy')
            features = np.asarray(features)
            np.save(feature_path, features)
            self.logger.info(
                'Features are saved at: {:s} '.format(feature_path))

        return aux_acc, class_acc
Example #6
    def test_multi_crop(self, test_loaders):
        num_crops = len(test_loaders)
        test_loaders_iterator = [
            iter(test_loaders[i]) for i in range(num_crops)
        ]

        num_test_iters = len(test_loaders[0])
        tt = tqdm(range(num_test_iters),
                  total=num_test_iters,
                  desc="Multi-crop test")

        aux_correct = 0
        class_correct = 0
        total = 0

        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                aux_logits_list = []
                class_logits_list = []
                for i in range(num_crops):
                    data = next(test_loaders_iterator[i])
                    if isinstance(data, list):
                        data = data[0]

                    # Get the inputs
                    data = to_device(data, self.device)
                    imgs = data['images']
                    if i == 0:
                        imgs_size = imgs.size(0)
                        cls_lbls = data['class_labels']
                        aux_lbls = data['aux_labels']

                    aux_logits_i, class_logits_i = self.model(imgs)
                    aux_logits_list.append(aux_logits_i)
                    class_logits_list.append(class_logits_i)

                aux_logits = sum(aux_logits_list)
                class_logits = sum(class_logits_list)
                _, cls_pred = class_logits.max(dim=1)
                _, aux_pred = aux_logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls.data)
                aux_correct += torch.sum(aux_pred == aux_lbls.data)
                total += imgs_size

        tt.close()

        aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('{} aux_acc: {:.2f} %, class_acc: {:.2f} %'.format(
            self.args.exp_name, aux_acc, class_acc))
        return aux_acc, class_acc
Example #7
def train_model(epochs_no, model_to_train, name: str):
    device = get_default_device()

    batches_to_device(train_loader, device)
    batches_to_device(val_loader, device)
    batches_to_device(test_loader, device)

    model = to_device(model_to_train, device)

    train(epochs_no, model, train_loader, val_loader)

    torch.save(model, f'E:/dxlat/Training stats/{name}.pt')
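
get_default_device and batches_to_device are not shown either. Since batches_to_device is called as a bare statement, it presumably mutates or wraps the loader in place; the sketch below expresses the same idea with an explicit wrapper (DeviceDataLoader is an assumed name, reusing the to_device helper sketched earlier):

import torch

def get_default_device():
    # Prefer the GPU when one is available.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class DeviceDataLoader:
    # Wrap a DataLoader so every batch it yields is moved to `device`.
    def __init__(self, loader, device):
        self.loader, self.device = loader, device

    def __iter__(self):
        for batch in self.loader:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.loader)
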
Example #8
    def test(self, epoch):
        msg = '\nTest at Epoch: {:d}'.format(epoch)
        print(msg)

        self.netFeat.eval()
        self.netClassifierVal.eval()

        top1 = AverageMeter()

        for batchIdx, data in enumerate(self.valLoader):
            data = to_device(data, self.device)

            SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
                data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

            SupportFeat, QueryFeat = self.netFeat(SupportTensor), self.netFeat(
                QueryTensor)
            SupportFeat, QueryFeat = SupportFeat.unsqueeze(
                0), QueryFeat.unsqueeze(0)

            clsScore = self.netClassifierVal(SupportFeat, QueryFeat)
            clsScore = clsScore.view(QueryFeat.size()[1], -1)

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])
            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, len(self.valLoader), msg)

        ## Save checkpoint.
        acc = top1.avg
        if acc > self.bestAcc:
            print('Saving Best')
            torch.save(self.netFeat.state_dict(),
                       os.path.join(self.outDir, 'netFeatBest.pth'))
            torch.save(self.netClassifier.state_dict(),
                       os.path.join(self.outDir, 'netClsBest.pth'))
            self.bestAcc = acc

        print('Saving Last')
        torch.save(self.netFeat.state_dict(),
                   os.path.join(self.outDir, 'netFeatLast.pth'))
        torch.save(self.netClassifier.state_dict(),
                   os.path.join(self.outDir, 'netClsLast.pth'))

        msg = 'Best Performance: {:.3f}'.format(self.bestAcc)
        print(msg)
        return top1.avg
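
The accuracy helper used throughout these snippets follows the top-k convention of the PyTorch ImageNet example (one percentage tensor per requested k); a sketch under that assumption:

import torch

def accuracy(output, target, topk=(1,)):
    # Fraction of samples whose true label appears in the top-k predictions,
    # returned as a list of one-element percentage tensors.
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
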
Example #9
    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")

        aux_correct = 0
        class_correct = 0
        total = 0
        soft_labels = np.zeros((1, 2))
        true_labels = []

        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:

                data = next(val_loader_iterator)
                data = to_device(data, self.device)
                imgs, cls_lbls, _, _ = data
                # Get the inputs

                logits = self.model(imgs, 'main_task')

                if self.config.save_output:
                    smax = nn.Softmax(dim=1)
                    smax_out = smax(logits)
                    soft_labels = np.concatenate(
                        (soft_labels, smax_out.cpu().numpy()), axis=0)
                    true_labels = np.append(true_labels,
                                            cls_lbls.cpu().numpy())

                _, cls_pred = logits.max(dim=1)
                # _, aux_pred = aux_logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls)
                # aux_correct += torch.sum(aux_pred == aux_lbls.data)
                total += imgs.size(0)

            tt.close()
        if self.config.save_output:
            soft_labels = soft_labels[1:, :]
            np.save('pred_cam1.npy', soft_labels)
            np.save('true_cam1.npy', true_labels)

        # aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('class_acc: {:.2f} %'.format(class_acc))
        return class_acc
Example #10
    def train_epoch_main_task(self, src_loader, tar_loader, epoch, print_freq):
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()

        for it, src_batch in enumerate(src_loader['main_task']):
            t = time.time()
            self.optimizer.zero_grad()
            src = to_device(src_batch, self.device)
            src_imgs, src_cls_lbls = src
            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            precision1_train, precision2_train = accuracy(src_main_logits,
                                                          src_cls_lbls,
                                                          topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))

            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), src_imgs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - t)

            self.start_iter += 1

            if self.start_iter % print_freq == 0:
                print_string = 'Epoch {:>2} | iter {:>4} | loss: {:.3f} | acc: {:.3f} | src_main: {:.3f} | {:4.2f} s/it'
                self.logger.info(
                    print_string.format(epoch, self.start_iter, losses.avg,
                                        top1.avg, main_loss.avg,
                                        batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
        self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})
    def evaluate_one_epoch(self,
                           model,
                           dataloader,
                           prefix,
                           debugging=False,
                           show_progress=False):
        """Evaluate the model for one epoch."""
        model.eval()
        tot_inp, tot_outp = [], []

        with torch.no_grad():
            total = 10 if debugging else len(dataloader)
            with tqdm(dataloader, total=total) as t:
                t.set_description(prefix)

                for i, data in enumerate(t):
                    # Input
                    data = to_device(data, self.device)
                    tot_inp.append(data)

                    # Forward
                    output = model(**data, is_training=False)
                    tot_outp.append(output)

                    # Break when reaching 10 iterations when debugging
                    if debugging and i == 9:
                        break

        acc = compute_metrics_from_inputs_and_outputs(
            inputs=tot_inp,
            outputs=tot_outp,
            show_progress=show_progress,
            output_acc=True,
            confidence_threshold=self.config["evaluation"]["confidence_threshold"])

        if acc is not None:
            self._record_metrics(acc)

            to_log = [f"{k}: {v.item():.3f}" for k, v in acc.items()]
            logger.info(f"{prefix}: {', '.join(to_log)}")

        model.train()
        return
    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")

        aux_correct = 0
        class_correct = 0
        aux_loss = 0
        class_loss = 0

        total = 0

        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                
                data = next(val_loader_iterator)
                if isinstance(data, list):
                    data = data[0]
                # Get the inputs
                data = to_device(data, self.device)
                imgs = data['images']
                cls_lbls = data['class_labels']
                aux_lbls = data['aux_labels']

                aux_logits, class_logits = self.model(imgs)

                aux_loss += self.class_loss_func(aux_logits, aux_lbls)
                class_loss += self.class_loss_func(class_logits, cls_lbls)

                _, cls_pred = class_logits.max(dim=1)
                _, aux_pred = aux_logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls.data)
                aux_correct += torch.sum(aux_pred == aux_lbls.data)
                total += imgs.size(0)

            tt.close()

        aux_acc = 100.0 * float(aux_correct) / total
        class_acc = 100.0 * float(class_correct) / total
        self.logger.info('aux acc: {:.2f} %, class_acc: {:.2f} %'.format(aux_acc, class_acc))
        return aux_acc, class_acc, aux_loss, class_loss
Example #13
    #         c = (predicted == l).squeeze()
    #         for i in range(30):
    #             label = l[i]
    #             class_correct[label] += c[i].item()
    #             class_total[label] += 1
    # classes = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
    # for i in range(6):
    #     print('Accuracy of %5s : %2d %%' % (
    #         classes[i], 100 * class_correct[i] / class_total[i]))

    print("resnet 50")
    checkpoint = torch.load('E:/dxlat/Training stats/model.pt')
    import copy
    res50 = copy.deepcopy(res50)
    res50.load_state_dict(checkpoint['model_state_dict'])
    res50 = to_device(res50, device)
    class_correct = list(0. for i in range(6))
    class_total = list(0. for i in range(6))
    with torch.no_grad():
        for batch in test_loader:
            images, lbls = batch
            images, lbls = images.cuda(), lbls.cuda()
            out = res50(images)
            _, predicted = torch.max(out, 1)
            c = (predicted == lbls).squeeze()
            for i in range(lbls.size(0)):
                label = lbls[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    classes = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
    for i in range(6):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
Example #14
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0
        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            # adjust learning rate
            self.scheduler.step()

            for it, (src_batch, tar_batch) in enumerate(zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()

                self.optimizer.zero_grad()
                if isinstance(src_batch, list):
                    src = src_batch[0] # data, dataset_idx
                else:
                    src = src_batch
                src = to_device(src, self.device)
                src_imgs = src['images']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                src_aux_logits, src_class_logits = self.model(src_imgs)
                src_aux_loss = self.class_loss_func(src_aux_logits, src_aux_lbls)

                # If true, the network will only try to classify the non scrambled images
                if self.args.training.only_non_scrambled:
                    src_class_loss = self.class_loss_func(
                            src_class_logits[src_aux_lbls == 0], src_cls_lbls[src_aux_lbls == 0])
                else:
                    src_class_loss = self.class_loss_func(src_class_logits, src_cls_lbls)

                tar = to_device(tar_batch, self.device)
                tar_imgs = tar['images']
                tar_aux_lbls = tar['aux_labels']
                tar_aux_logits, tar_class_logits = self.model(tar_imgs)
                tar_aux_loss = self.class_loss_func(tar_aux_logits, tar_aux_lbls)
                tar_entropy_loss = self.entropy_loss(tar_class_logits[tar_aux_lbls==0])

                loss = src_class_loss + src_aux_loss * self.args.training.src_aux_weight
                loss += tar_aux_loss * self.args.training.tar_aux_weight
                loss += tar_entropy_loss * self.args.training.tar_entropy_weight

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | src_class: {:.3f} | src_aux: {:.3f} | tar_entropy: {:.3f} | tar_aux: {:.3f} |{:4.2f} s/it'
                    self.logger.info(print_string.format(epoch, i_iter,
                        src_aux_loss.item(),
                        src_class_loss.item(),
                        tar_entropy_loss.item(),
                        tar_aux_loss.item(),
                        batch_time.avg))
                    self.writer.add_scalar('losses/src_class_loss', src_class_loss, i_iter)
                    self.writer.add_scalar('losses/src_aux_loss', src_aux_loss, i_iter)
                    self.writer.add_scalar('losses/tar_entropy_loss', tar_entropy_loss, i_iter)
                    self.writer.add_scalar('losses/tar_aux_loss', tar_aux_loss, i_iter)

            del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
            del src_aux_logits, src_class_logits
            del tar_aux_logits, tar_class_logits

            # validation
            self.save(self.args.model_dir, i_iter)

            if val_loader is not None:
                self.logger.info('validating...')
                aux_acc, class_acc = self.test(val_loader)
                self.writer.add_scalar('val/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc, i_iter)

            if test_loader is not None:
                self.logger.info('testing...')
                aux_acc, class_acc = self.test(test_loader)
                self.writer.add_scalar('test/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    # todo copy current model to best model
                self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
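
The zip(src_loader, itertools.cycle(tar_loader)) pattern above pairs every source batch with a target batch, restarting the (typically smaller) target loader as needed; iteration stops once the source loader is exhausted. A minimal illustration:

import itertools

src = [1, 2, 3, 4, 5]
tar = ['a', 'b']
print(list(zip(src, itertools.cycle(tar))))
# [(1, 'a'), (2, 'b'), (3, 'a'), (4, 'b'), (5, 'a')]
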
Example #15
    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")
        loss = AverageMeter()
        aux_correct = 0
        class_correct = 0
        total = 0
        if self.config.dataset == 'kather':
            soft_labels = np.zeros((1, 9))
        if self.config.dataset == 'oscc' or self.config.dataset == 'cam':
            soft_labels = np.zeros((1, 2))
        true_labels = []
        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                data = next(val_loader_iterator)
                data = to_device(data, self.device)
                imgs, cls_lbls = data
                # Get the inputs
                logits = self.model(imgs, 'main_task')
                test_loss = self.class_loss_func(logits, cls_lbls)
                loss.update(test_loss.item(), imgs.size(0))
                if self.config.save_output:
                    smax = nn.Softmax(dim=1)
                    smax_out = smax(logits)
                    soft_labels = np.concatenate(
                        (soft_labels, smax_out.cpu().numpy()), axis=0)
                    true_labels = np.append(true_labels,
                                            cls_lbls.cpu().numpy())
                _, cls_pred = logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls)
                total += imgs.size(0)

            tt.close()
        self.wandb.log({"Test Loss": loss.avg})
        # if self.config.save_output == True:
        soft_labels = soft_labels[1:, :]
        if self.config.dataset == 'oscc' or self.config.dataset == 'cam':
            AUC = calculate_stat(soft_labels,
                                 true_labels,
                                 2,
                                 self.config.class_names,
                                 type='binary',
                                 thresh=0.5)
        if self.config.dataset == 'kather':
            AUC = calculate_stat(soft_labels,
                                 true_labels,
                                 9,
                                 self.config.class_names,
                                 type='multi',
                                 thresh=0.5)
        class_acc = 100 * float(class_correct) / total
        self.logger.info('class_acc: {:.2f} %'.format(class_acc))
        self.wandb.log({"Test acc": class_acc, "Test AUC": 100 * AUC})
        return class_acc, AUC
Example #16
    def train_epoch_all_tasks(self, src_loader, tar_loader, epoch, print_freq):
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()
        start_steps = epoch * len(tar_loader['main_task'])
        total_steps = self.config.num_epochs * len(tar_loader['main_task'])

        max_num_iter_src = max([
            len(src_loader[task_name]) for task_name in self.config.task_names
        ])
        for it in range(max_num_iter_src):
            t = time.time()

            # this is based on DANN paper
            p = float(it + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            self.optimizer.zero_grad()

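            # NOTE: next(iter(loader)) builds a fresh iterator on every step,
            # paying loader start-up cost each time and never advancing through
            # an epoch; creating the iterator once outside the loop (or using
            # itertools.cycle) is the usual fix.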
            src = next(iter(src_loader['main_task']))
            tar = next(iter(tar_loader['main_task']))
            src = to_device(src, self.device)
            tar = to_device(tar, self.device)
            src_imgs, src_cls_lbls = src
            tar_imgs, _ = tar

            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            tar_main_logits = self.model(tar_imgs, 'main_task')
            tar_main_loss = self.entropy_loss(tar_main_logits)
            loss += tar_main_loss
            tar_aux_loss = {}
            src_aux_loss = {}

            #TO DO: separating dataloaders and iterate over tasks
            for task in self.config.task_names:
                if self.config.tasks[task]['type'] == 'classification_adapt':
                    r = torch.randperm(src_imgs.size()[0] + tar_imgs.size()[0])
                    src_tar_imgs = torch.cat((src_imgs, tar_imgs), dim=0)
                    src_tar_imgs = src_tar_imgs[r, :, :, :]
                    src_tar_img = src_tar_imgs[:src_imgs.size()[0], :, :, :]
                    src_tar_lbls = torch.cat((torch.zeros(
                        (src_imgs.size()[0])), torch.ones(
                            (tar_imgs.size()[0]))),
                                             dim=0)
                    src_tar_lbls = src_tar_lbls[r]
                    src_tar_lbls = src_tar_lbls[:src_imgs.size()[0]]
                    src_tar_lbls = src_tar_lbls.long().cuda()
                    src_tar_logits = self.model(src_tar_img,
                                                'domain_classifier', alpha)
                    tar_aux_loss['domain_classifier'] = self.class_loss_func(
                        src_tar_logits, src_tar_lbls)
                    loss += tar_aux_loss[
                        'domain_classifier'] * self.config.loss_weight[
                            'domain_classifier']
                if self.config.tasks[task]['type'] == 'classification_self':
                    src = next(iter(src_loader[task]))
                    tar = next(iter(tar_loader[task]))
                    src = to_device(src, self.device)
                    tar = to_device(tar, self.device)
                    src_aux_imgs, src_aux_lbls = src
                    tar_aux_imgs, tar_aux_lbls = tar
                    tar_aux_logits = self.model(tar_aux_imgs, task)
                    src_aux_logits = self.model(src_aux_imgs, task)
                    tar_aux_loss[task] = self.class_loss_func(
                        tar_aux_logits, tar_aux_lbls)
                    src_aux_loss[task] = self.class_loss_func(
                        src_aux_logits, src_aux_lbls)
                    loss += src_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: magnification weight
                    loss += tar_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: main task weight
                if self.config.tasks[task]['type'] == 'pixel_self':
                    src = next(iter(src_loader[task]))
                    tar = next(iter(tar_loader[task]))
                    src = to_device(src, self.device)
                    tar = to_device(tar, self.device)
                    src_aux_imgs, src_aux_lbls = src
                    tar_aux_imgs, tar_aux_lbls = tar
                    tar_aux_mag_logits = self.model(tar_aux_imgs, task)
                    src_aux_mag_logits = self.model(src_aux_imgs, task)
                    tar_aux_loss[task] = self.pixel_loss(
                        tar_aux_mag_logits, tar_aux_lbls)
                    src_aux_loss[task] = self.pixel_loss(
                        src_aux_mag_logits, src_aux_lbls)
                    loss += src_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: magnification weight
                    loss += tar_aux_loss[task] * self.config.loss_weight[task]

            precision1_train, precision2_train = accuracy(src_main_logits,
                                                          src_cls_lbls,
                                                          topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))
            loss.backward()
            self.optimizer.step()
            losses.update(loss.item(), src_imgs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - t)
            self.start_iter += 1
            if self.start_iter % print_freq == 0:
                printt = ''
                for task_name in self.config.aux_task_names:
                    if task_name == 'domain_classifier':
                        printt = printt + ' | tar_aux_' + task_name + ': {:.3f} |'
                    else:
                        printt = printt + 'src_aux_' + task_name + ': {:.3f} | tar_aux_' + task_name + ': {:.3f}'
                print_string = 'Epoch {:>2} | iter {:>4} | loss:{:.3f} |  acc: {:.3f} | src_main: {:.3f} |' + printt + '{:4.2f} s/it'
                src_aux_loss_all = [
                    loss.item() for loss in src_aux_loss.values()
                ]
                tar_aux_loss_all = [
                    loss.item() for loss in tar_aux_loss.values()
                ]
                self.logger.info(
                    print_string.format(epoch, self.start_iter, losses.avg,
                                        top1.avg, main_loss.avg,
                                        *src_aux_loss_all, *tar_aux_loss_all,
                                        batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
                for task_name in self.config.aux_task_names:
                    if task_name == 'domain_classifier':
                        # self.writer.add_scalar('losses/src_aux_loss_'+task_name, src_aux_loss[task_name], i_iter)
                        self.writer.add_scalar(
                            'losses/tar_aux_loss_' + task_name,
                            tar_aux_loss[task_name], self.start_iter)
                    else:
                        self.writer.add_scalar(
                            'losses/src_aux_loss_' + task_name,
                            src_aux_loss[task_name], self.start_iter)
                        self.writer.add_scalar(
                            'losses/tar_aux_loss_' + task_name,
                            tar_aux_loss[task_name], self.start_iter)
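            # NOTE: the scheduler is stepped once per iteration here, unlike
            # train_epoch_main_task above, which steps it once per epoch.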
            self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})
Example #17
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        """
        Run one epoch on train-set.

        :param trainLoader: the dataloader of train-set
        :type trainLoader: class `TrainLoader`
        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD
        :param float coeffGrad: deprecated
        """
        bestAcc, ci = self.validate(valLoader, lr, 'test')
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
                nC, nH, nW = SupportTensor.shape[2:]

                # SupportFeat = self.netFeat(SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = self.pretrain.get_features(
                    SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

                # QueryFeat = self.netFeat(QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = self.pretrain.get_features(
                    QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

            if lr is None:
                lr = self.optimizer.param_groups[0]['lr']

            self.optimizer.zero_grad()

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1],
                                     -1)

            # Inductive
            '''
            clsScore = torch.zeros(QueryFeat.shape[1], 5).cuda()
            for i in range(QueryFeat.shape[1]):
                singleScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat[:, i, :].unsqueeze(1), lr)
                clsScore[i] = singleScore[0][0]
            '''

            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.shape[0])
            losses.update(loss.item(), QueryFeat.shape[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}%'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader, lr, 'test')

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(
                        {
                            'lr': lr,
                            'netFeat': self.netFeat.state_dict(),
                            'SIB': self.netSIB.state_dict(),
                            'nbStep': self.nStep,
                        }, os.path.join(self.outDir, 'netSIBBest.pth'))

                self.logger.info('Saving Last')
                torch.save(
                    {
                        'lr': lr,
                        'netFeat': self.netFeat.state_dict(),
                        'SIB': self.netSIB.state_dict(),
                        'nbStep': self.nStep,
                    }, os.path.join(self.outDir, 'netSIBLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%, Best Acc {:.3f}%'.format(
                    episode, losses.avg, top1.avg, acc, bestAcc)
                self.logger.info(msg)
                self.write_output_message(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
    def train_one_epoch(self, model, dataloader, optimizer, scheduler, num_epochs, max_grad_norm=None,
                        debugging=False):
        """Train the model for one epoch."""
        model.train()
        timer = Timer()

        print(
            ("{:25}" + "|" + "{:^15}" * (3 + len(self.early_stopping_metrics)) + "|").format(
                "", "l1_cls_loss", "l2_cls_loss", "l3_cls_loss", *self.early_stopping_metrics)
        )

        total = 10 if debugging else len(dataloader)
        with tqdm(dataloader, total=total) as t:
            if num_epochs is not None:
                description = f"Training ({self.epoch}/{num_epochs})"
            else:
                description = "Training"
            t.set_description(description)

            for i, data in enumerate(t):
                timer.start()

                data = to_device(data, self.device)
                optimizer.zero_grad()

                # Forward
                output = model(**data)
                losses = output["losses"]

                # Calculate batch metrics
                metric = compute_metrics_from_inputs_and_outputs(
                    inputs=data, outputs=output, tokenizer=self.tokenizer, save_csv_path=None)
                losses.update(metric)

                # Update tqdm with training information
                to_tqdm = []  # update tqdm
                for loss_type in ["l1_cls_loss", "l2_cls_loss", "l3_cls_loss", *self.early_stopping_metrics]:
                    loss_n = losses[loss_type]

                    if isinstance(loss_n, torch.Tensor) and torch.isnan(loss_n):
                        to_tqdm.append("nan")
                    else:
                        to_tqdm.append(f"{loss_n.item():.3f}")

                des = (
                    "{:25}" + "|" + "{:^15}" * (3 + len(self.early_stopping_metrics)) + "|"
                ).format(description, *to_tqdm)
                t.set_description(des)

                # Backward
                losses["total_loss"].backward()
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                timer.end()

                # Break when reaching 10 iterations when debugging
                if debugging and i == 9:
                    break

        logger.info(f"{description} took {timer.get_total_time():.2f}s.")
        return
    def validate(self, valLoader, mode='val'):
        if mode == 'test':
            nEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nEpisode))
        elif mode == 'val':
            nEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nEpisode))
            valLoader = iter(valLoader)
        else:
            raise ValueError('Unknown mode: {}'.format(mode))

        episodeAccLog = []
        top1 = AverageMeter()

        self.netFeat.eval()
        self.netRefine.eval()
        self.netClassifier.eval()

        #for batchIdx, data in enumerate(valLoader):
        for batchIdx in range(nEpisode):
            data = valLoader.getEpisode() if mode == 'test' else next(
                valLoader)
            data = to_device(data, self.device)

            SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                    data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
                    data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

            with torch.no_grad():
                SupportFeat, QueryFeat = self.netFeat(
                    SupportTensor), self.netFeat(QueryTensor)
                SupportFeat, QueryFeat, SupportLabel = \
                        SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)
                nbSupport, nbQuery = SupportFeat.size()[1], QueryFeat.size()[1]

                feat = torch.cat((SupportFeat, QueryFeat), dim=1)
                refine_feat = self.netRefine(feat)
                refine_feat = feat + refine_feat
                refine_support, refine_query = refine_feat.narrow(
                    1, 0,
                    nbSupport), refine_feat.narrow(1, nbSupport, nbQuery)
                clsScore = self.netClassifier(refine_support, SupportLabel,
                                              refine_query)
                clsScore = clsScore.squeeze(0)
            QueryLabel = QueryLabel.view(-1)
            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])

            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, nEpisode, msg)
            episodeAccLog.append(acc1[0].item())

        mean, ci95 = getCi(episodeAccLog)
        self.logger.info(
            'Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.
            format(mean, ci95))

        self.netRefine.train()
        self.netClassifier.train()
        return mean, ci95
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print_freq = 1 #max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0
        for epoch in range(start_epoch, start_epoch + num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            # adjust learning rate
            self.scheduler.step()

            for it, (src_batch, tar_batch) in enumerate(zip(itertools.cycle(src_loader), tar_loader)):
                t = time.time()

                self.optimizer.zero_grad()

                if isinstance(src_batch, list):
                    src = src_batch[0] # data, dataset_idx
                else:
                    src = src_batch
                src = to_device(src, self.device)

                src_imgs = src['images']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                src_aux_logits, src_class_logits = self.model(src_imgs)
                src_aux_loss = self.class_loss_func(src_aux_logits, src_aux_lbls)

                _, cls_pred = src_class_logits.max(dim=1)
                _, aux_pred = src_aux_logits.max(dim=1)

                # If true, the network will only try to classify the non scrambled images
                if self.args.training.only_non_scrambled:
                    src_class_loss = self.class_loss_func(
                            src_class_logits[src_aux_lbls == 0], src_cls_lbls[src_aux_lbls == 0])
                    true_pos_class = torch.sum(cls_pred[src_aux_lbls == 0] \
                         == src_cls_lbls[src_aux_lbls == 0]).to(dtype=torch.float)
                    num_samples_accu = src_cls_lbls[src_aux_lbls == 0].size(0)
                else:
                    src_class_loss = self.class_loss_func(src_class_logits, src_cls_lbls)
                    true_pos_class = torch.sum(cls_pred == src_cls_lbls).to(dtype=torch.float)
                    num_samples_accu = src_imgs.size(0)

                tar = to_device(tar_batch, self.device)
                tar_imgs = tar['images']
                tar_aux_lbls = tar['aux_labels']
                tar_aux_logits, _ = self.model(tar_imgs)
                tar_aux_loss = self.class_loss_func(tar_aux_logits, tar_aux_lbls)

                _, tar_aux_pred = tar_aux_logits.max(dim=1)

                loss = src_class_loss + src_aux_loss * self.args.training.src_aux_weight + \
                    tar_aux_loss * self.args.training.tar_aux_weight

                aux_loss = src_aux_loss + tar_aux_loss

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))
                
                class_acc = true_pos_class/num_samples_accu
                aux_acc = ((torch.sum(aux_pred == src_aux_lbls).to(dtype=torch.float) +
                    torch.sum(tar_aux_pred == tar_aux_lbls).to(dtype=torch.float))/(src_imgs.size(0) + tar_imgs.size(0)))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | aux_loss : {:.3f} | class_loss : {:.3f} |class_acc: {:.3f} | aux_acc: {:.3f} | {:4.2f} s/it'
                    self.logger.info(print_string.format(epoch, i_iter,
                        aux_loss.item(),
                        src_class_loss.item(),
                        class_acc.item(),
                        aux_acc.item(),
                        batch_time.avg))
                    self.writer.add_scalar('losses/src_class_loss', src_class_loss, i_iter)
                    self.writer.add_scalar('losses/aux_loss', aux_loss, i_iter)
                    wandb.log({"epoch": epoch,
                            "Iterations": i_iter,
                            "train_class_loss": src_class_loss.item(),
                            "train_aux_loss": aux_loss.item(),
                            "train_class_acc" : class_acc.item(),
                            "train_aux_acc" : aux_acc.item()})

            del loss, src_class_loss, src_aux_loss, aux_loss
            del src_aux_logits, src_class_logits

            # validation
            if val_loader is not None:
                self.logger.info('validating...')
                aux_acc, class_acc, aux_loss, class_loss = self.test(val_loader)
                self.writer.add_scalar('val/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc, i_iter)
                wandb.log({"epoch": epoch,
                            "val_class_loss": class_loss.item(),
                            "val_aux_loss": aux_loss.item(),
                            "val_class_acc" : class_acc,
                            "val_aux_acc" : aux_acc})
                 # save best model
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.args.model_dir, epoch)
                    
            if test_loader is not None:
                self.logger.info('testing...')
                aux_acc, class_acc = self.test(test_loader)
                self.writer.add_scalar('test/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    # todo copy current model to best model
                self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
Example #21
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print_freq = max(num_batches // self.config.training_num_print_epoch,
                         1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.config.num_epochs
        best_acc = 0
        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            # adjust learning rate
            self.scheduler.step()

            for it, src_batch in enumerate(src_loader):
                t = time.time()

                self.optimizer.zero_grad()
                src = src_batch
                src = to_device(src, self.device)
                src_imgs, src_cls_lbls, src_aux_imgs, src_aux_lbls = src

                self.optimizer.zero_grad()

                src_main_logits = self.model(src_imgs, 'main_task')
                src_main_loss = self.class_loss_func(src_main_logits,
                                                     src_cls_lbls)
                loss = src_main_loss * self.config.loss_weight['main_task']

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | loss: {:.3f} | src_main: {:.3f} | {:4.2f} s/it'
                    self.logger.info(
                        print_string.format(epoch, i_iter, losses.avg,
                                            src_main_loss.item(), batch_time.avg))
                    self.writer.add_scalar('losses/all_loss', losses.avg, i_iter)
                    self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                           i_iter)
            # del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
            # del src_aux_logits, src_class_logits
            # del tar_aux_logits, tar_class_logits

            # validation
            self.save(self.config.model_dir, i_iter)

            if val_loader is not None:
                self.logger.info('validating...')
                class_acc = self.test(val_loader)
                # self.writer.add_scalar('val/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.config.best_model_dir, i_iter)
                    # todo copy current model to best model
                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(best_acc))

            if test_loader is not None:
                self.logger.info('testing...')
                class_acc = self.test(test_loader)
                # self.writer.add_scalar('test/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    # todo copy current model to best model
                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print('Number of batches: %d' % num_batches)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0

        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            for it, (src_batch, tar_batch) in enumerate(
                    zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()

                if isinstance(src_batch, list):
                    src = src_batch[0]  # data, dataset_idx
                else:
                    src = src_batch
                src = to_device(src, self.device)
                src_imgs = src['images']
                src_cls_lbls = src['class_labels']

                self.optimizer = inv_lr_scheduler(
                    self.optimizer,
                    i_iter,
                    lr=self.args.training.optimizer.lr,
                    wd=self.args.training.optimizer.weight_decay)

                self.optimizer.zero_grad()

                src_feats, src_class_logits = self.model(src_imgs)
                src_class_loss = self.class_loss_func(src_class_logits,
                                                      src_cls_lbls)

                tar = to_device(tar_batch, self.device)
                tar_imgs = tar['images']
                tar_feats, tar_class_logits = self.model(tar_imgs)

                features = torch.cat((src_feats, tar_feats), dim=0)
                outputs = torch.cat((src_class_logits, tar_class_logits),
                                    dim=0)
                softmax_out = nn.Softmax(dim=1)(outputs)

                if self.args.method == 'cdan':
                    transfer_loss = cdan_loss([features, softmax_out],
                                              self.adv_model, None, None,
                                              self.random_layer)
                elif self.args.method == 'cdan+e':
                    entropy = compute_entropy(softmax_out)
                    transfer_loss = cdan_loss([features, softmax_out],
                                              self.adv_model, entropy,
                                              calc_coeff(i_iter),
                                              self.random_layer)
                elif self.args.method == 'dann':
                    transfer_loss = dann_loss(features, self.adv_model)
                else:
                    raise ValueError('Unrecognized method: {}'.format(self.args.method))

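                # Total objective: supervised source loss plus the weighted adversarial transfer loss.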
                loss = src_class_loss + transfer_loss * self.args.training.transfer_loss_weight

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | class_loss: {:.3f} | transfer_loss: {:.3f} | {:4.2f} s/it'
                    self.logger.info(
                        print_string.format(epoch, i_iter,
                                            src_class_loss.item(),
                                            transfer_loss.item(),
                                            batch_time.avg))
                    self.writer.add_scalar('losses/src_class_loss',
                                           src_class_loss, i_iter)
                    self.writer.add_scalar('losses/transfer_loss',
                                           transfer_loss, i_iter)

            del loss, src_class_loss, transfer_loss
            del src_class_logits
            del tar_class_logits

            if test_loader:
                self.logger.info('testing...')
                class_acc = self.test(test_loader)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.args.model_dir, i_iter, is_best=True)

                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
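
A hedged sketch of the inv_lr_scheduler helper called in the loop above (an assumption based on the common "inverse" decay schedule from the DANN/CDAN literature; gamma and power are hypothetical defaults, not taken from this code):

def inv_lr_scheduler(optimizer, iter_num, lr=0.001, wd=0.0005,
                     gamma=0.001, power=0.75):
    # Inverse decay: lr_t = lr * (1 + gamma * t) ** (-power).
    decayed_lr = lr * (1 + gamma * iter_num) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = decayed_lr
        param_group['weight_decay'] = wd
    return optimizer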
Example #23
                          train_config,
                          sort=False,
                          drop_last=False)
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_config["optimizer"]["batch_size"] * 4,
        shuffle=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_config["optimizer"]["batch_size"],
        shuffle=False,
        collate_fn=val_dataset.collate_fn,
    )

    n_batch = 0
    for batchs in train_loader:
        for batch in batchs:
            to_device(batch, device)
            n_batch += 1
    print("Training set  with size {} is composed of {} batches.".format(
        len(train_dataset), n_batch))

    n_batch = 0
    for batchs in val_loader:
        for batch in batchs:
            to_device(batch, device)
            n_batch += 1
    print("Validation set  with size {} is composed of {} batches.".format(
        len(val_dataset), n_batch))
    def train_epoch(self):
        """
        Train the network for one epoch and return the average loss.
        * This will be a pessimistic approximation of the true loss
        of the network, as the loss of the first batches will be higher
        than the loss at the end of the epoch.

        Returns:
            loss (float, list(float)): list of mean losses

        """
        self.model.train()
        losses = []

        self.epoch += 1
        epoch_start = time.time()

        if isinstance(self.train_loader, (tuple, list)):
            iterator = zip(*self.train_loader)
        else:
            iterator = self.train_loader

        for i_batch, batch in enumerate(iterator, 1):

            self.step += 1

            for optimizer in self.optimizers:
                optimizer.zero_grad()

            if isinstance(batch.text[0], list):
                X = []
                for item in batch.text[0]:
                    item_tensor = torch.from_numpy(numpy.array(item))
                    X.append(to_device(item_tensor,
                                       device=self.device,
                                       dtype=item_tensor.dtype))

            else:
                X = to_device(batch.text[0], device=self.device, dtype=batch.text[0].dtype)

            y = to_device(batch.label, device=self.device, dtype=torch.long)

            lengths = to_device(batch.text[1], device=self.device, dtype=torch.long)

            batch_loss, _, _ = self.process_batch(X, lengths, y)

            # aggregate the losses into a single loss value
            loss_sum, loss_list = self.return_tensor_and_list(batch_loss)
            losses.append(loss_list)

            # back-propagate
            loss_sum.backward()

            # if self.clip is not None:
            #     for optimizer in self.optimizers:
            #         clip_grad_norm_((p for group in optimizer.param_groups
            #                          for p in group['params']), self.clip)

            # update weights
            for optimizer in self.optimizers:
                optimizer.step()

            if self.step % self.log_interval == 0:
                self.progress_log = epoch_progress(self.epoch, i_batch,
                                                   self.train_batch_size,
                                                   self.train_set_size,
                                                   epoch_start)

            for c in self.batch_end_callbacks:
                if callable(c):
                    c(i_batch, batch_loss)

        return numpy.array(losses).mean(axis=0)
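
A minimal sketch of the return_tensor_and_list helper assumed by train_epoch (hypothetical, written here as a free function; the actual method may differ): it collapses a single loss tensor, or a list of them, into one backprop-able sum plus detached floats for logging.

def return_tensor_and_list(batch_loss):
    # Accept either a single loss tensor or a list/tuple of loss tensors.
    if isinstance(batch_loss, (list, tuple)):
        loss_sum = sum(batch_loss)
        loss_list = [loss.item() for loss in batch_loss]
    else:
        loss_sum = batch_loss
        loss_list = [batch_loss.item()]
    return loss_sum, loss_list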
Example #25
    def validate(self, valLoader, lr=None, mode='val'):
        """
        Run one epoch on val-set.
        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD
        :param string mode: 'val' or 'test'
        """
        if mode == 'test':
            nEpisode = self.nEpisode
            self.logger.info(
                '\n\nTest mode: randomly sample {:d} episodes...'.format(
                    nEpisode))
        elif mode == 'val':
            nEpisode = len(valLoader)
            self.logger.info(
                '\n\nValidation mode: pre-defined {:d} episodes...'.format(
                    nEpisode))
            valLoader = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        episodeAccLog = []
        top1 = AverageMeter()

        self.netFeat.eval()
        #self.netSIB.eval() # set train mode, since updating bn helps to estimate better gradient

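        # Default to the optimizer's current learning rate for the synthetic gradient step.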
        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for batchIdx in range(nEpisode):
            data = valLoader.getEpisode() if mode == 'test' else next(
                valLoader)
            data = to_device(data, self.device)

            SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                    data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
                    data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

            with torch.no_grad():
                # SupportFeat, QueryFeat = self.netFeat(SupportTensor), self.netFeat(QueryTensor)
                SupportFeat, QueryFeat = self.pretrain.get_features(
                    SupportTensor), self.pretrain.get_features(QueryTensor)
                SupportFeat, QueryFeat, SupportLabel = \
                        SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1],
                                     -1)

            # Inductive
            '''
            clsScore = torch.zeros(QueryFeat.shape[1], 5).cuda()
            for i in range(QueryFeat.shape[1]):
                singleScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat[:, i, :].unsqueeze(1), lr)
                clsScore[i] = singleScore[0][0]
            '''

            QueryLabel = QueryLabel.view(-1)

            if self.davg:
                # diff_scores = self.calc_diff_scores(self.pretrain, SupportFeat.squeeze(0), QueryFeat.squeeze(0), SupportLabel.squeeze(0), QueryLabel)  # cosine similarity
                diff_scores = self._evaluate_hardness_logodd(
                    self.pretrain,
                    SupportFeat.squeeze(0), QueryFeat.squeeze(0),
                    SupportLabel.squeeze(0), QueryLabel)  # logodd
            else:
                diff_scores = None
            acc1 = accuracy(clsScore,
                            QueryLabel,
                            topk=(1, ),
                            diff_scores=diff_scores)
            top1.update(acc1[0].item(), clsScore.shape[0])

            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, nEpisode, msg)
            episodeAccLog.append(acc1[0].item())

        mean, ci95 = getCi(episodeAccLog)
        msg = 'Final accuracy with 95% confidence interval: {:.3f}% +- {:.3f}%'.format(
            mean, ci95)
        self.logger.info(msg)
        self.write_output_message(msg)
        return mean, ci95
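
getCi is assumed to return the mean episode accuracy together with the half-width of a 95% confidence interval; a standard sketch under that assumption (not this repo's code):

import numpy

def getCi(acc_log, z=1.96):
    accs = numpy.asarray(acc_log)
    mean = accs.mean()
    # Half-width of the 95% CI under a normal approximation.
    ci95 = z * accs.std(ddof=1) / numpy.sqrt(len(accs))
    return mean, ci95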
Example #26
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        bestAcc, ci = self.validate(valLoader, lr)
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']

                SupportFeat = self.netFeat(SupportTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))
                QueryFeat = self.netFeat(QueryTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))

                SupportFeat, QueryFeat = SupportFeat.contiguous().view(self.batchSize, -1, self.nFeat), \
                        QueryFeat.view(self.batchSize, -1, self.nFeat)

            if lr is None:
                lr = self.optimizer.param_groups[0]['lr']

            self.optimizer.zero_grad()

            clsScore = self.netSIB(lr, SupportFeat, SupportLabel, QueryFeat)
            clsScore = clsScore.view(QueryFeat.size()[0] * QueryFeat.size()[1],
                                     -1)
            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])
            losses.update(loss.item(), QueryFeat.size()[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader, lr)

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(
                        {
                            'lr': lr,
                            'netFeat': self.netFeat.state_dict(),
                            'SIB': self.netSIB.state_dict(),
                            'nbStep': self.nStep,
                        }, os.path.join(self.outDir, 'netSIBBest.pth'))

                self.logger.info('Saving Last')
                torch.save(
                    {
                        'lr': lr,
                        'netFeat': self.netFeat.state_dict(),
                        'SIB': self.netSIB.state_dict(),
                        'nbStep': self.nStep,
                    }, os.path.join(self.outDir, 'netSIBLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                    episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
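
The accuracy helper used throughout these snippets follows the standard top-k pattern; a sketch of that pattern (the version above also accepts a diff_scores argument, which this sketch omits):

def accuracy(output, target, topk=(1,)):
    # Returns a list with one top-k accuracy (in %) per requested k.
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res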
Example #27
    def LrWarmUp(self, totalIter, lr):
        msg = '\nLearning rate warming up'
        print(msg)

        self.optimizer = torch.optim.SGD(itertools.chain(
            self.netFeat.parameters(),
            self.netClassifier.parameters()),
                                         lr=1e-7,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=True)

        nbIter = 0
        lrUpdate = lr
        valTop1 = 0

        while nbIter < totalIter:
            self.netFeat.train()
            self.netClassifier.train()
            losses = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()

            for batchIdx, (inputs, targets) in enumerate(self.trainLoader):
                nbIter += 1
                if nbIter == totalIter:
                    break

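                # Linear warmup: ramp the learning rate from ~0 up to the target lr over totalIter iterations.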
                lrUpdate = nbIter / float(totalIter) * lr
                for g in self.optimizer.param_groups:
                    g['lr'] = lrUpdate

                inputs = to_device(inputs, self.device)
                targets = to_device(targets, self.device)

                self.optimizer.zero_grad()
                outputs = self.netFeat(inputs)
                outputs = self.netClassifier(outputs)
                loss = self.criterion(outputs, targets)

                loss.backward()
                self.optimizer.step()

                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), inputs.size()[0])
                top1.update(acc1[0].item(), inputs.size()[0])
                top5.update(acc5[0].item(), inputs.size()[0])

                msg = 'Loss: {:.3f} | Lr : {:.5f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(
                    losses.avg, lrUpdate, top1.avg, top5.avg)
                progress_bar(batchIdx, len(self.trainLoader), msg)

        with torch.no_grad():
            valTop1 = self.test(0)

        self.optimizer = torch.optim.SGD(itertools.chain(
            self.netFeat.parameters(),
            self.netClassifier.parameters()),
                                         lr=lrUpdate,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=True)

        self.lrScheduler = MultiStepLR(self.optimizer,
                                       milestones=self.milestones,
                                       gamma=0.1)
        return valTop1
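
AverageMeter, used in every example here, is assumed to be the standard running-average tracker from the PyTorch ImageNet reference code (sketch):

class AverageMeter:
    """Tracks the latest value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count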
Example #28
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print('Number of batches: %d' % num_batches)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0

        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            for it, (src_batch, tar_batch) in enumerate(
                    zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()

                if isinstance(src_batch, list):
                    src = src_batch[0]  # data, dataset_idx
                else:
                    src = src_batch
                src = to_device(src, self.device)
                src_imgs = src['images']
                src_imgs_ori = src['images_ori']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                #self.optimizer = inv_lr_scheduler(self.optimizer, i_iter,
                #        lr=self.args.training.optimizer.lr, wd=self.args.training.optimizer.weight_decay)

                self.optimizer.zero_grad()

                src_aux_logits, src_class_logits = self.model(src_imgs)
                src_aux_loss = self.class_loss_func(src_aux_logits,
                                                    src_aux_lbls)

                # If enabled, the network only tries to classify the non-scrambled images
                if self.args.training.get('only_non_scrambled'):
                    src_class_loss = self.class_loss_func(
                        src_class_logits[src_aux_lbls == 0],
                        src_cls_lbls[src_aux_lbls == 0])
                else:
                    src_class_loss = self.class_loss_func(
                        src_class_logits, src_cls_lbls)

                src_kld_loss = torch.tensor(0.0)
                if self.args.training.get('src_kld_weight'):
                    _, ori_logits = self.model(src_imgs_ori)
                    aug_logits = src_class_logits
                    src_kld_loss = self.KLD_loss_func(
                        F.log_softmax(aug_logits, 1),
                        F.softmax(ori_logits.detach(), 1))

                tar = to_device(tar_batch, self.device)
                tar_imgs = tar['images']
                tar_imgs_ori = tar['images_ori']
                tar_aux_lbls = tar['aux_labels']
                tar_aux_logits, tar_class_logits = self.model(tar_imgs)
                tar_aux_loss = self.class_loss_func(tar_aux_logits,
                                                    tar_aux_lbls)
                tar_entropy_loss = self.entropy_loss(
                    tar_class_logits[tar_aux_lbls == 0])

                tar_kld_loss = torch.tensor(0.0)
                if self.args.training.get('tar_kld_weight'):
                    _, ori_logits = self.model(tar_imgs_ori)
                    aug_logits = tar_class_logits
                    tar_kld_loss = self.KLD_loss_func(
                        F.log_softmax(aug_logits, 1),
                        F.softmax(ori_logits.detach(), 1))

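                # Combine the supervised source loss with the weighted auxiliary (self-supervised), entropy, and KLD consistency terms.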
                loss = src_class_loss + src_aux_loss * self.args.training.src_aux_weight

                if self.args.training.get('src_kld_weight'):
                    loss += self.args.training.src_kld_weight * src_kld_loss

                loss += tar_aux_loss * self.args.training.tar_aux_weight
                loss += tar_entropy_loss * self.args.training.tar_entropy_weight

                if self.args.training.get('tar_kld_weight'):
                    loss += self.args.training.tar_kld_weight * tar_kld_loss

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | src_class: {:.3f} | src_aux: {:.3f} | tar_aux: {:.3f} | tar_entropy: {:.3f} | src_kld: {:.3f} | tar_kld: {:.3f} | {:4.2f} s/it'
                    self.logger.info(
                        print_string.format(epoch, i_iter,
                                            src_class_loss.item(),
                                            src_aux_loss.item(),
                                            tar_aux_loss.item(),
                                            tar_entropy_loss.item(),
                                            src_kld_loss.item(),
                                            tar_kld_loss.item(),
                                            batch_time.avg))
                    self.writer.add_scalar('losses/src_class_loss',
                                           src_class_loss, i_iter)
                    self.writer.add_scalar('losses/src_aux_loss', src_aux_loss,
                                           i_iter)
                    self.writer.add_scalar('losses/tar_aux_loss', tar_aux_loss,
                                           i_iter)
                    self.writer.add_scalar('losses/tar_entropy_loss',
                                           tar_entropy_loss, i_iter)

            # adjust learning rate
            self.scheduler.step()

            del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
            del src_aux_logits, src_class_logits
            del tar_aux_logits, tar_class_logits

            # validation
            # if val_loader:
            # self.logger.info('validating...')
            # aux_acc, class_acc = self.test(val_loader)
            # self.writer.add_scalar('val/aux_acc', aux_acc, i_iter)
            # self.writer.add_scalar('val/class_acc', class_acc, i_iter)

            if test_loader:
                self.logger.info('testing...')
                aux_acc, class_acc = self.test(test_loader)
                self.writer.add_scalar('test/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.args.model_dir, i_iter, is_best=True)

                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
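
A plausible sketch of the entropy_loss used for tar_entropy_loss above (an assumption): the mean Shannon entropy of the softmax predictions, which pushes unlabeled target predictions toward confident, low-entropy outputs.

import torch.nn.functional as F

def entropy_loss(logits):
    p = F.softmax(logits, dim=1)
    log_p = F.log_softmax(logits, dim=1)
    # Mean per-sample entropy; minimizing it sharpens the predictions.
    return -(p * log_p).sum(dim=1).mean()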
    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0):
        bestAcc, ci = self.validate(valLoader)
        self.logger.info(
            'Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.
            format(bestAcc, ci))

        self.netRefine.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss': [], 'trainAcc': [], 'valAcc': []}

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            with torch.no_grad():
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']

                SupportFeat = self.netFeat(SupportTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))
                QueryFeat = self.netFeat(QueryTensor.contiguous().view(
                    -1, 3, self.inputW, self.inputH))

                SupportFeat, QueryFeat = SupportFeat.contiguous().view(self.batchSize, -1, self.nFeat), \
                        QueryFeat.contiguous().view(self.batchSize, -1, self.nFeat)

            self.optimizer.zero_grad()

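            # Jointly refine support and query features; the refinement is applied as a residual on the backbone features.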
            nbSupport, nbQuery = SupportFeat.size()[1], QueryFeat.size()[1]
            feat = torch.cat((SupportFeat, QueryFeat), dim=1)
            refine_feat = self.netRefine(feat)
            refine_feat = feat + refine_feat
            refine_support, refine_query = refine_feat.narrow(
                1, 0, nbSupport), refine_feat.narrow(1, nbSupport, nbQuery)
            clsScore = self.netClassifier(refine_support, SupportLabel,
                                          refine_query)

            clsScore = clsScore.view(
                refine_query.size()[0] * refine_query.size()[1], -1)
            QueryLabel = QueryLabel.view(-1)

            loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])
            losses.update(loss.item(), QueryFeat.size()[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            progress_bar(episode, self.nbIter, msg)

            if episode % 1000 == 999:
                acc, _ = self.validate(valLoader)

                if acc > bestAcc:
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(
                        bestAcc, acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    torch.save(
                        {
                            'lr': lr,
                            'netFeat': self.netFeat.state_dict(),
                            'netRefine': self.netRefine.state_dict(),
                            'netClassifier': self.netClassifier.state_dict(),
                        }, os.path.join(self.outDir, 'netBest.pth'))

                self.logger.info('Saving Last')
                torch.save(
                    {
                        'lr': lr,
                        'netFeat': self.netFeat.state_dict(),
                        'netRefine': self.netRefine.state_dict(),
                        'netClassifier': self.netClassifier.state_dict(),
                    }, os.path.join(self.outDir, 'netLast.pth'))

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                    episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
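
Every snippet in this collection leans on a to_device helper. A minimal recursive version, as a sketch (signatures vary across the examples above; some pass a dtype, some rely on in-place movement):

import torch

def to_device(data, device, dtype=None):
    # Move tensors (optionally casting) and recurse into containers.
    if isinstance(data, torch.Tensor):
        return data.to(device=device, dtype=dtype)
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype) for v in data)
    return data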
    def train_one_epoch(self,
                        model,
                        dataloader,
                        optimizer,
                        scheduler,
                        num_epochs,
                        max_grad_norm=None,
                        debugging=False):
        """Train the model for one epoch."""
        model.train()
        timer = Timer()

        print(("{:25}" + "|" + "{:^45}" + "|" + "{:^45}" + "|" + "{:^45}" +
               "|").format("", "food", "service", "price"))
        print(("{:25}" + "|" + "{:^15}" * 3 + "|" + "{:^15}" * 3 + "|" +
               "{:^15}" * 3 + "|").format("", "mse_loss", "existence_loss",
                                          "acc", "mse_loss", "existence_loss",
                                          "acc", "mse_loss", "existence_loss",
                                          "acc"))

        total = 10 if debugging else len(dataloader)
        with tqdm(dataloader, total=total) as t:
            if num_epochs is not None:
                description = f"Training ({self.epoch}/{num_epochs})"
            else:
                description = "Training"
            t.set_description(description)

            for i, data in enumerate(t):
                timer.start()

                data = to_device(data, self.device)
                optimizer.zero_grad()

                # Forward
                output = model(**data, is_training=True)
                losses = output["losses"]

                # Calculate batch accuracy
                acc = compute_metrics_from_inputs_and_outputs(
                    inputs=data,
                    outputs=output,
                    output_acc=True,
                    confidence_threshold=self.config["evaluation"]
                    ["confidence_threshold"])
                losses.update(acc)

                # Update tqdm with training information
                to_tqdm = []  # update tqdm
                for name in ["food", "service", "price"]:
                    for loss_type in ["score_loss", "existence_loss", "acc"]:
                        n = f"{name}_{loss_type}"
                        loss_n = losses[n]

                        if (not isinstance(
                                loss_n, torch.Tensor)) or torch.isnan(loss_n):
                            to_tqdm.append("nan")
                        else:
                            to_tqdm.append(f"{loss_n.item():.3f}")

                des = ("{:25}" + "|" + "{:^15}" * 3 + "|" + "{:^15}" * 3 +
                       "|" + "{:^15}" * 3 + "|").format(description, *to_tqdm)
                t.set_description(des)

                # Backward
                losses["total_loss"].backward()
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   max_grad_norm)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                timer.end()

                # Break when reaching 10 iterations when debugging
                if debugging and i == 9:
                    break

        logger.info(f"{description} took {timer.get_total_time():.2f}s.")
        return
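
train_one_epoch assumes a small Timer utility with start/end/get_total_time; a sketch consistent with those calls (hypothetical, not the original class):

import time

class Timer:
    def __init__(self):
        self.total = 0.0
        self._start = None

    def start(self):
        self._start = time.time()

    def end(self):
        # Accumulate the elapsed time since the matching start().
        if self._start is not None:
            self.total += time.time() - self._start
            self._start = None

    def get_total_time(self):
        return self.total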