def trainer(self, train_loader, criterion, optimizer, epoch, cum_epochs,
            update):
        """Train for one epoch on (anchor, positive, negative) triplets.

        Args:
            train_loader: yields ``((ida, xa, xp), (idn, xn))`` batches —
                anchor/positive pairs with their ids, plus negatives.
            criterion: triplet loss over the three embeddings.
            optimizer: optimizer stepping this model's parameters.
            epoch: epoch index within the current run.
            cum_epochs: epochs already completed in earlier runs (log display).
            update: when True, push per-batch similarity scores back into the
                dataset's sampling probabilities and reset them at epoch end.
        """
        logging.info('\n' + '-' * 200 + '\n' + '\t' * 10 + 'TRAINING\n')
        loss_meter = AverageMeter()
        self.train()
        for step, ((anchor_id, anchor, positive),
                   (neg_id, negative)) in enumerate(train_loader):
            # Insert a channel dimension, then move each tensor to the device.
            anchor = anchor.unsqueeze(1).to(self.device)
            positive = positive.unsqueeze(1).to(self.device)
            negative = negative.unsqueeze(1).to(self.device)
            emb_a, emb_p, emb_n = self.forward(anchor, positive, negative)
            optimizer.zero_grad()
            loss = criterion(emb_a, emb_p, emb_n)
            loss_meter.update(loss.item(), anchor.size(0))
            loss.backward()
            optimizer.step()

            if (step + 1) % self.show_every == 0:
                logging.info('\tEpoch: [{0}][{1}/{2}]\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                                 cum_epochs + epoch + 1,
                                 step + 1,
                                 len(train_loader),
                                 loss=loss_meter))
            if update:
                # Widened margin (1.5x) for the resampling score heuristic.
                score_update = update_cos(emb_a, emb_p, emb_n,
                                          1.5 * self.margin)
                train_loader.dataset.update(score_update, anchor_id, neg_id)
        if update:
            logging.info('\n{}'.format(
                np.array_str(train_loader.dataset.probability_matrix,
                             precision=2,
                             suppress_small=True)))
            train_loader.dataset.reset()

        logging.info('\n' + '-' * 200)
Example #2
0
    def trainer(self, train_loader, criterion, optimizer, epoch, cum_epochs):
        """Train for one epoch on (x1, x2, label) pairs.

        Args:
            train_loader: yields ``(x1, x2, label)`` pair batches.
            criterion: pairwise loss over the two embeddings and the label.
            optimizer: optimizer stepping this model's parameters.
            epoch: epoch index within the current run.
            cum_epochs: epochs already completed in earlier runs (log display).
        """
        logging.info('\n' + '-' * 200 + '\n' + '\t' * 10 + 'TRAINING\n')
        loss_meter = AverageMeter()
        self.train()
        for step, (left, right, label) in enumerate(train_loader):
            # Insert a channel dimension on both inputs.
            left = left.unsqueeze(1)
            right = right.unsqueeze(1)
            if not self.normalize:
                # Map {0, 1} labels to {-1, +1} for the un-normalized loss.
                label = 2 * label - 1
            left = left.to(self.device)
            right = right.to(self.device)
            label = label.to(self.device)
            emb_left, emb_right = self.forward(left, right)
            optimizer.zero_grad()
            loss = criterion(emb_left.squeeze(), emb_right.squeeze(),
                             label.squeeze().float())
            loss_meter.update(loss.item(), left.size(0))
            loss.backward()
            optimizer.step()
            if (step + 1) % self.show_every == 0:
                logging.info('\tEpoch: [{0}][{1}/{2}]\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                                 cum_epochs + epoch + 1,
                                 step + 1,
                                 len(train_loader),
                                 loss=loss_meter))

        logging.info('\n' + '-' * 200)
Example #3
0
def train_fn(train_loader,
             model,
             optimizer,
             criterion,
             scheduler,
             device,
             print_log: int = 10):
    """Train ``model`` for one epoch over combined plant/disease labels.

    Args:
        train_loader: yields ``(_, image, plant, disease)`` batches.
        model: classifier producing logits over the combined label space.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss on (logits, combined-label) pairs.
        scheduler: LR scheduler, stepped once at epoch end.
        device: device to run the forward/backward pass on.
        print_log: print a progress line every ``print_log`` batches.
    """
    losses = AverageMeter()
    hammings = AverageMeter()
    model.train()
    for i, (_, image, plant, disease) in enumerate(train_loader):
        image = image.to(device)
        # Fold the two label tensors into a single combined class index.
        combined = multi_label_tensors_to_single_label_tensor(plant, disease)
        combined = combined.to(device)

        optimizer.zero_grad()

        outputs = model(image)
        loss = criterion(outputs, combined)

        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        acc = cal_hamming_loss(plant, disease, preds)

        if (i + 1) % print_log == 0:
            print(f'  Train - loss: {loss.item():.4f} hamming loss: {acc}')

        # FIX: store plain Python floats in the meters, not live tensors —
        # accumulating `loss` tensors keeps every batch's autograd graph
        # alive (memory leak); the sibling `train` loop already uses .item().
        losses.update(loss.item(), image.size(0))
        hammings.update(float(acc), image.size(0))

    scheduler.step()
Example #4
0
def train(train_loader, model, optimizer, criterion, device, scheduler):
    """Train ``model`` for one epoch, showing running averages via tqdm.

    Args:
        train_loader: yields ``(_, image, plant, disease)`` batches.
        model: classifier over the combined plant/disease label space.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss on (logits, combined-label) pairs.
        device: device to run the forward/backward pass on.
        scheduler: LR scheduler, stepped once at epoch end.
    """
    model.train()
    loss_meter = AverageMeter()
    hamming_meter = AverageMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for _, image, plant, disease in progress:
        image = image.to(device)
        # Fold the two label tensors into one combined class index.
        combined = multi_label_tensors_to_single_label_tensor(plant, disease)
        combined = combined.to(device)

        optimizer.zero_grad()

        outputs = model(image)
        loss = criterion(outputs, combined)

        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        hamming = cal_hamming_loss(plant, disease, preds)

        batch = image.size(0)
        loss_meter.update(loss.item(), batch)
        hamming_meter.update(hamming.item(), batch)

        progress.set_postfix(loss=loss_meter.avg, hamming=hamming_meter.avg)

    scheduler.step()
Example #5
0
    def validation_step(self, summary_dev):
        """Run one full pass over the validation loader.

        Args:
            summary_dev (dict): validation summary; its ``'loss'`` entry is
                overwritten with the epoch's average loss.

        Returns:
            tuple: ``(summary_dev, output_, target_)`` — the updated summary
            plus numpy arrays of stacked predictions and ground-truth labels.
        """
        losses = AverageMeter()
        # Disable grads process-wide; training_step re-enables them later.
        torch.set_grad_enabled(False)
        self.model.eval()

        output_ = np.array([])
        target_ = np.array([])

        with torch.no_grad():
            for i, (inputs, target, _) in enumerate(self.valid_loader):
                target = target.to(self.device)
                if isinstance(inputs, tuple):
                    # Multi-input batch: move only the tensor entries.
                    inputs = tuple(
                        e.to(self.device) if isinstance(e, torch.Tensor) else e
                        for e in inputs)
                else:
                    inputs = inputs.to(self.device)

                logits = self.forward(inputs)
                loss = self.criterion(logits, target)
                losses.update(loss.item(), target.size(0))

                if self.hparams.multi_cls:
                    # FIX: implicit-dim softmax is deprecated and ambiguous;
                    # dim=1 matches the torch.max(..., 1) reduction below.
                    output = F.softmax(logits, dim=1)
                    _, output = torch.max(output, 1)
                else:
                    output = torch.sigmoid(logits)

                # Accumulate batch results; the first batch seeds the arrays.
                target = target.detach().to('cpu').numpy()
                target_ = np.concatenate(
                    (target_, target), axis=0) if len(target_) > 0 else target
                y_pred = output.detach().to('cpu').numpy()
                output_ = np.concatenate(
                    (output_, y_pred), axis=0) if len(output_) > 0 else y_pred

        summary_dev['loss'] = losses.avg
        return summary_dev, output_, target_
Example #6
0
    def evaluate(self, loader):
        """Evaluate reconstruction quality over ``loader``.

        Runs the model on every batch, accumulates PSNR/SSIM of the
        reconstruction against the fully-sampled target, and stores both the
        averages (``self.psnr_recon``, ``self.ssim_recon``) and the stacked
        per-batch images (``self.results``) on the instance. Returns nothing.
        """
        val_bar = tqdm(loader)
        avg_psnr = AverageMeter()
        avg_ssim = AverageMeter()

        recon_images = []
        gt_images = []
        input_images = []

        for data in val_bar:
            # set_input/forward populate self.recon, self.tag_image_full,
            # self.tag_image_sub as side effects.
            self.set_input(data)
            self.forward()

            # Metrics are only meaningful when the L1 recon loss is active.
            if self.opts.wr_L1 > 0:
                psnr_recon = psnr(complex_abs_eval(self.recon),
                                  complex_abs_eval(self.tag_image_full))
                avg_psnr.update(psnr_recon)

                # SSIM is computed on the first item/channel of the batch only.
                ssim_recon = ssim(complex_abs_eval(self.recon)[0,0,:,:].cpu().numpy(),
                                  complex_abs_eval(self.tag_image_full)[0,0,:,:].cpu().numpy())
                avg_ssim.update(ssim_recon)

                # Keep the first sample of each batch for later inspection.
                recon_images.append(self.recon[0].cpu())
                gt_images.append(self.tag_image_full[0].cpu())
                input_images.append(self.tag_image_sub[0].cpu())

            message = 'PSNR: {:4f} '.format(avg_psnr.avg)
            message += 'SSIM: {:4f} '.format(avg_ssim.avg)
            val_bar.set_description(desc=message)

        self.psnr_recon = avg_psnr.avg
        self.ssim_recon = avg_ssim.avg

        self.results = {}
        if self.opts.wr_L1 > 0:
            self.results['recon'] = torch.stack(recon_images).squeeze().numpy()
            self.results['gt'] = torch.stack(gt_images).squeeze().numpy()
            self.results['input'] = torch.stack(input_images).squeeze().numpy()
Example #7
0
def valid_fn(data_loader, model, optimizer, criterion, scheduler, device):
    """Evaluate ``model`` on ``data_loader`` and return the average loss.

    Args:
        data_loader: yields ``(_, image, plant, disease)`` batches.
        model: classifier over the combined plant/disease label space.
        optimizer: unused; kept for signature parity with ``train_fn``.
        criterion: loss on (logits, combined-label) pairs.
        scheduler: unused; kept for signature parity with ``train_fn``.
        device: device to run inference on.

    Returns:
        The epoch's average loss as a Python float.
    """
    with torch.no_grad():
        model.eval()
        losses = AverageMeter()
        hammings = AverageMeter()
        for _, image, plant, disease in data_loader:
            image = image.to(device)
            # Fold the two label tensors into one combined class index.
            combined = multi_label_tensors_to_single_label_tensor(
                plant, disease)
            combined = combined.to(device)

            outputs = model(image)
            loss = criterion(outputs, combined)

            _, preds = torch.max(outputs, 1)
            acc = cal_hamming_loss(plant, disease, preds)

            # FIX: store plain floats, not tensors — keeps the meters and the
            # `{losses.avg:.4f}` format below robust, and matches the sibling
            # `evaluate` loop which already calls .item().
            losses.update(loss.item(), image.size(0))
            hammings.update(float(acc), image.size(0))

        print(f'  Valid - loss: {losses.avg:.4f} hamming loss: {hammings.avg}')

    return losses.avg
Example #8
0
def evaluate(data_loader, model, criterion, device):
    """Evaluate ``model`` on ``data_loader``; return the average hamming loss.

    Args:
        data_loader: yields ``(_, image, plant, disease)`` batches.
        model: classifier over the combined plant/disease label space.
        criterion: loss on (logits, combined-label) pairs.
        device: device to run inference on.
    """
    model.eval()
    loss_meter = AverageMeter()
    hamming_meter = AverageMeter()
    data_loader = tqdm(data_loader, total=len(data_loader))
    with torch.no_grad():
        for _, image, plant, disease in data_loader:
            image = image.to(device)
            # Fold the two label tensors into one combined class index.
            combined = multi_label_tensors_to_single_label_tensor(
                plant, disease)
            combined = combined.to(device)

            outputs = model(image)
            loss = criterion(outputs, combined)

            _, preds = torch.max(outputs, 1)
            hamming = cal_hamming_loss(plant, disease, preds)

            batch = image.size(0)
            loss_meter.update(loss.item(), batch)
            hamming_meter.update(hamming.item(), batch)

            data_loader.set_postfix(loss=loss_meter.avg,
                                    hamming=hamming_meter.avg)

    return hamming_meter.avg
    def evaluate(self, loader):
        """Evaluate image-translation quality over ``loader``.

        Runs the model on every batch, accumulates PSNR/MSE/SSIM of the
        predicted image against the ground truth, and stores the averages
        (``self.psnr``, ``self.ssim``, ``self.mse``) and the stacked per-batch
        images (``self.results``) on the instance. Returns nothing.
        """
        val_bar = tqdm(loader)
        avg_psnr = AverageMeter()
        avg_ssim = AverageMeter()
        avg_mse = AverageMeter()

        pred_images = []
        gt_images = []
        gt_inp_images = []

        for data in val_bar:
            # set_input/forward populate self.IH, self.IT, self.IT_fake
            # as side effects.
            self.set_input(data)
            self.forward(self.IH)

            # The +1 offset shifts images out of a presumed [-1, 1] range
            # before metric computation — TODO confirm against set_input.
            psnr_ = psnr(self.IT_fake + 1, self.IT + 1)
            mse_ = mse(self.IT_fake + 1, self.IT + 1)
            # SSIM is computed on the first item/channel of the batch only.
            ssim_ = ssim(self.IT_fake[0, 0, ...].cpu().numpy() + 1,
                         self.IT[0, 0, ...].cpu().numpy() + 1)
            avg_psnr.update(psnr_)
            avg_mse.update(mse_)
            avg_ssim.update(ssim_)

            # Keep the first sample of each batch for later inspection.
            pred_images.append(self.IT_fake[0].cpu())
            gt_images.append(self.IT[0].cpu())
            gt_inp_images.append(self.IH[0].cpu())

            message = 'PSNR: {:4f} '.format(avg_psnr.avg)
            message += 'SSIM: {:4f} '.format(avg_ssim.avg)
            message += 'MSE: {:4f} '.format(avg_mse.avg)
            val_bar.set_description(desc=message)

        self.psnr = avg_psnr.avg
        self.ssim = avg_ssim.avg
        self.mse = avg_mse.avg

        self.results = {}
        self.results['pred_IT'] = torch.stack(pred_images).squeeze().numpy()
        self.results['gt_IT'] = torch.stack(gt_images).squeeze().numpy()
        self.results['gt_IH'] = torch.stack(gt_inp_images).squeeze().numpy()
Example #10
0
def val_epoch(epoch, data_loader, model, criterion, opt, logger):
    """Validate ``model`` for one epoch and return the average loss.

    Args:
        epoch: epoch index (for display and logging).
        data_loader: yields ``(inputs, targets, video_ids)`` batches.
        model: model under evaluation.
        criterion: loss on (outputs, targets) pairs.
        opt: options namespace; ``opt.save_features`` tags the model with the
            current sample's id for feature dumping.
        logger: object with a ``.log(dict)`` method for epoch-level stats.

    Returns:
        The epoch's average loss.
    """
    print('validation at epoch {}'.format(epoch))

    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    # FIX: Variable(..., volatile=True) was removed in PyTorch 0.4; the
    # modern equivalent of a no-grad evaluation pass is torch.no_grad().
    with torch.no_grad():
        for i, (inputs, targets, video_ids) in enumerate(data_loader):
            data_time.update(time.time() - end_time)

            if opt.save_features:
                # Tag the model so the feature-dump hook knows the sample id.
                model.module.label = video_ids[0] + str(targets.tolist()[0])

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)

            # FIX: loss.data[0] raises on 0-dim tensors since PyTorch 0.4.
            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(epoch,
                                                             i + 1,
                                                             len(data_loader),
                                                             batch_time=batch_time,
                                                             data_time=data_time,
                                                             loss=losses,
                                                             acc=accuracies))

    logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    return losses.avg
Example #11
0
def validate():
    """Validate a pretrained ViT-B/16 on ImageNet using Jittor.

    Builds the model and validation dataset, runs a full evaluation pass,
    prints periodic progress lines, and finally prints top-1/top-5 accuracy
    with their error rates. Paths and batch size are hard-coded.
    """
    bs = 256
    # create model
    model = create_model('vit_base_patch16_224',
                         pretrained=True,
                         num_classes=1000)
    criterion = nn.CrossEntropyLoss()

    dataset = create_val_dataset(root='/data/imagenet',
                                 batch_size=bs,
                                 num_workers=4,
                                 img_size=224)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with jt.no_grad():
        # Warm-up forward pass on random data so the first timed batch does
        # not include compilation/allocation overhead.
        input = jt.random((bs, 3, 224, 224))
        model(input)

        end = time.time()
        for batch_idx, (input, target) in enumerate(dataset):
            # dataset.display_worker_status()
            batch_size = input.shape[0]
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            # NOTE(review): loss/acc1/acc5 look like jt.Var objects being
            # stored in the meters rather than Python floats — confirm that
            # AverageMeter arithmetic and round() below handle jt.Var.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss, batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % 10 == 0:
                # jt.sync_all(True)
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(dataset),
                        batch_time=batch_time,
                        rate_avg=batch_size / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

            # if batch_idx>50:break

    # Note: `top1`/`top5` are rebound from meters to rounded scalars here.
    top1a, top5a = top1.avg, top5.avg
    top1 = round(top1a, 4)
    top1_err = round(100 - top1a, 4)
    top5 = round(top5a, 4)
    top5_err = round(100 - top5a, 4)

    print(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        top1, top1_err, top5, top5_err))
Example #12
0
    def training_step(self, summary_train, summary_dev, best_dict):
        """Train the model for one epoch over ``self.train_loader``.

        Supports three training modes selected by config: plain training,
        mixup/cutmix-style mixing (``self.hparams.mixtype``), and AugMix-style
        Jensen-Shannon consistency training (when ``self.cfg.no_jsd`` is
        False). Periodically logs progress and runs validation.

        Args:
            summary_train (dict): running training state; 'step' and 'epoch'
                are advanced in place.
            summary_dev (dict): validation summary passed to validation_end.
            best_dict (dict): best-checkpoint bookkeeping passed to
                validation_end.

        Returns:
            tuple: the updated ``(summary_train, best_dict)``.
        """
        losses = AverageMeter()
        torch.set_grad_enabled(True)
        self.model.train()
        time_now = time.time()

        for i, (inputs, target, _) in enumerate(self.train_loader):
            # A tuple of inputs means a multi-view batch (e.g. JSD mode);
            # move only the tensor entries to the device.
            if isinstance(inputs, tuple):
                inputs = tuple([
                    e.to(self.device) if type(e) == torch.Tensor else e
                    for e in inputs
                ])
            else:
                inputs = inputs.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()

            if self.cfg.no_jsd:
                # --- Plain (non-JSD) path, with optional n-crop + mixing ---
                if self.cfg.n_crops:
                    # Flatten the crop dimension into the batch dimension.
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)

                    if len(self.hparams.mixtype) > 0:
                        # Mix inputs; targets are repeated per crop first.
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs,
                                target.repeat(1, n_crops).view(-1),
                                self.device, self.hparams.alpha)
                        else:
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs,
                                target.repeat(1, n_crops).view(
                                    -1, len(self.num_tasks)), self.device,
                                self.hparams.alpha)

                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        # Mixed loss: lam-weighted combination of both targets.
                        loss_func = self.mixup_criterion(
                            targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            loss = self.criterion(
                                logits,
                                target.repeat(1, n_crops).view(-1))
                        else:
                            loss = self.criterion(
                                logits,
                                target.repeat(1, n_crops).view(
                                    -1, len(self.num_tasks)))
                else:
                    # Single-crop path, with optional mixing.
                    if len(self.hparams.mixtype) > 0:
                        inputs, targets_a, targets_b, lam = self.mix_data(
                            inputs, target, self.device, self.hparams.alpha)

                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        loss_func = self.mixup_criterion(
                            targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        loss = self.criterion(logits, target)
            else:
                # --- AugMix/JSD path: inputs = (clean, aug1, aug2) views ---
                # One forward pass over all three views, then split back.
                images_all = torch.cat(inputs, 0)
                logits_all = self.forward(images_all)
                logits_clean, logits_aug1, logits_aug2 = torch.split(
                    logits_all, inputs[0].size(0))

                # Cross-entropy is only computed on clean images
                loss = F.cross_entropy(logits_clean, target)

                p_clean, p_aug1, p_aug2 = F.softmax(
                    logits_clean,
                    dim=1), F.softmax(logits_aug1,
                                      dim=1), F.softmax(logits_aug2, dim=1)

                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                        1).log()
                # JSD consistency term (weight 12, per the AugMix recipe).
                loss += 12 * (
                    F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

            # Fail fast on divergence rather than training on NaNs.
            assert not np.isnan(
                loss.item()), 'Model diverged with losses = NaN'

            loss.backward()
            self.optimizer.step()
            summary_train['step'] += 1
            losses.update(loss.item(), target.size(0))

            if summary_train['step'] % self.hparams.log_every == 0:
                time_spent = time.time() - time_now
                time_now = time.time()
                logging.info('Train, '
                             'Epoch : {}, '
                             'Step : {}/{}, '
                             'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                             'Run Time : {runtime:.2f} sec'.format(
                                 summary_train['epoch'] + 1,
                                 summary_train['step'],
                                 summary_train['total_step'],
                                 loss=losses,
                                 runtime=time_spent))
                print('Train, '
                      'Epoch : {}, '
                      'Step : {}/{}, '
                      'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                      'Run Time : {runtime:.2f} sec'.format(
                          summary_train['epoch'] + 1,
                          summary_train['step'],
                          summary_train['total_step'],
                          loss=losses,
                          runtime=time_spent))

            if summary_train['step'] % self.hparams.test_every == 0:
                self.validation_end(summary_dev, summary_train, best_dict)

            # validation_end may have switched the model to eval / no-grad;
            # restore training mode before the next batch.
            self.model.train()
            torch.set_grad_enabled(True)

        summary_train['epoch'] += 1
        return summary_train, best_dict
Example #13
0
def train_epoch(epoch, data_loader, model, criterion, optimizer, opt,
                epoch_logger, batch_logger, save_features):
    """Train ``model`` for one epoch and periodically checkpoint.

    Args:
        epoch: 1-based epoch index (used in logs and the iteration count).
        data_loader: yields ``(inputs, targets, video_ids)`` batches.
        model: model being trained.
        criterion: loss on (outputs, targets) pairs.
        optimizer: optimizer stepping ``model``'s parameters.
        opt: options namespace (``no_cuda``, ``save_features``,
            ``checkpoint``, ``result_path``, ``arch``).
        epoch_logger: ``.log(dict)`` sink for epoch-level stats.
        batch_logger: ``.log(dict)`` sink for batch-level stats.
        save_features: unused; kept for backward-compatible signature.
    """
    print('train at epoch {}'.format(epoch))

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets, video_ids) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        if not opt.no_cuda:
            # FIX: `async` is a reserved keyword since Python 3.7, making
            # `.cuda(async=True)` a SyntaxError; PyTorch 0.4 renamed the
            # kwarg to `non_blocking`. (The Variable() wrappers were also
            # dropped — they are no-ops on modern PyTorch.)
            targets = targets.cuda(non_blocking=True)

        if opt.save_features:
            # Tag the model so the feature-dump hook knows the sample id.
            model.module.label = video_ids[0] + str(targets.tolist()[0])

        outputs = model(inputs)

        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        # FIX: loss.data[0] raises on 0-dim tensors since PyTorch 0.4.
        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        batch_logger.log({
            'epoch': epoch,
            'batch': i + 1,
            'iter': (epoch - 1) * len(data_loader) + (i + 1),
            'loss': losses.val,
            'acc': accuracies.val,
            'lr': optimizer.param_groups[0]['lr']
        })

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                  epoch,
                  i + 1,
                  len(data_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss=losses,
                  acc=accuracies))

    epoch_logger.log({
        'epoch': epoch,
        'loss': losses.avg,
        'acc': accuracies.avg,
        'lr': optimizer.param_groups[0]['lr']
    })

    if epoch % opt.checkpoint == 0:
        save_file_path = os.path.join(opt.result_path,
                                      'save_{}.pth'.format(epoch))
        states = {
            'epoch': epoch + 1,
            'arch': opt.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(states, save_file_path)