Example #1
    def validation(self, args, model, testloader, epoch):
        model.eval()
        criterion = nn.CrossEntropyLoss(reduction='mean')  # size_average is deprecated; use reduction

        metrics = Metrics()
        metrics.reset()
        with torch.no_grad():
            for batch_idx, input_tensors in enumerate(testloader):

                input_data, target = input_tensors
                if (args.cuda):
                    input_data = input_data.cuda()
                    target = target.cuda()

                output = model(input_data)

                loss = criterion(output, target)

                correct, total, acc = accuracy(output, target)
                num_samples = batch_idx * args.batch_size + 1

                metrics.update({
                    'correct': correct,
                    'total': total,
                    'loss': loss.item(),
                    'accuracy': acc
                })
                print_stats(args, epoch, num_samples, testloader, metrics)

        print_summary(args, epoch, num_samples, metrics, mode="Validation")
        return metrics
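
The accuracy helper used throughout these examples is not shown; the following is a minimal sketch consistent with its (correct, total, acc) return signature, assuming output holds raw logits of shape [batch, num_classes] (a hypothetical implementation, not the repository's code):

import torch

def accuracy(output, target):
    # Hypothetical helper matching the (correct, total, acc) usage above
    _, pred = torch.max(output, dim=1)       # predicted class per sample
    correct = (pred == target).sum().item()  # number of correct predictions
    total = target.size(0)                   # batch size
    acc = 100.0 * correct / total            # batch accuracy in percent
    return correct, total, acc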
Example #2
def train(args, model, trainloader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metrics = Metrics('')
    metrics.reset()
    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        if (args.cuda):
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)

        num_samples = batch_idx * args.batch_size + 1
        metrics.update({
            'correct': correct,
            'total': total,
            'loss': loss.item(),
            'accuracy': acc
        })
        print_stats(args, epoch, num_samples, trainloader, metrics)

    print_summary(args, epoch, num_samples, metrics, mode="Training")
    return metrics
Example #3
def train(model, args, device, writer, optimizer, data_loader, epoch):

    # Set train mode
    model.train()

    criterion = nn.CrossEntropyLoss(reduction='mean')
    metric_ftns = ['loss', 'correct', 'total', 'accuracy', 'sens', 'ppv']
    metrics = MetricTracker(*metric_ftns, writer=writer, mode='train')
    metrics.reset()

    cm = torch.zeros(args.classes, args.classes)

    for batch_idx, input_tensors in enumerate(data_loader):

        input_data, target = input_tensors[0].to(device), input_tensors[1].to(
            device)

        # Forward
        output = model(input_data)
        loss = criterion(output, target)

        correct, total, acc = accuracy(output, target)
        update_confusion_matrix(cm, output, target)
        metrics.update_all_metrics({
            'correct': correct,
            'total': total,
            'loss': loss.item(),
            'accuracy': acc
        })

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save TB stats
        writer_step = (epoch - 1) * len(data_loader) + batch_idx
        if ((batch_idx + 1) % args.log_interval == 0):

            # Calculate confusion for this bucket
            ppv, sens = update_confusion_calc(cm)
            metrics.update_all_metrics({'sens': sens, 'ppv': ppv})
            cm = torch.zeros(args.classes, args.classes)

            metrics.write_tb(writer_step)

            num_samples = batch_idx * args.batch_size
            print_stats(args, epoch, num_samples, data_loader, metrics)

    return metrics, writer_step
Example #4
def validation(args, model, testloader, epoch, writer):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = [
        'loss', 'correct', 'total', 'accuracy', 'ppv', 'sensitivity'
    ]
    val_metrics = MetricTracker(*metric_ftns, writer=writer, mode='val')
    val_metrics.reset()
    confusion_matrix = torch.zeros(args.class_dict, args.class_dict)
    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            if (args.cuda):
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            loss = criterion(output, target)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batch_size + 1
            _, pred = torch.max(output, 1)
            for t, p in zip(target.cpu().view(-1), pred.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'total': total,
                    'loss': loss.item(),
                    'accuracy': acc
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    print_summary(args, epoch, num_samples, val_metrics, mode="Validation")
    s = sensitivity(confusion_matrix.numpy())
    ppv = positive_predictive_value(confusion_matrix.numpy())
    print(f" s {s} ,ppv {ppv}")
    val_metrics.update('sensitivity',
                       s,
                       writer_step=(epoch - 1) * len(testloader) + batch_idx)
    val_metrics.update('ppv',
                       ppv,
                       writer_step=(epoch - 1) * len(testloader) + batch_idx)
    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
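
Examples #4 and #9 call sensitivity and positive_predictive_value on the accumulated confusion matrix (rows = true class, columns = predicted class). A plausible per-class implementation, sketched under that layout assumption rather than taken from the repository:

import numpy as np

def sensitivity(cm, eps=1e-12):
    # Per-class recall: diagonal (true positives) over row sums (actual positives)
    return np.diag(cm) / (cm.sum(axis=1) + eps)

def positive_predictive_value(cm, eps=1e-12):
    # Per-class precision: diagonal over column sums (predicted positives)
    return np.diag(cm) / (cm.sum(axis=0) + eps)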
Example #5
def train(args, model, trainloader, optimizer, epoch):

    start_time = time.time()
    model.train()

    train_metrics = MetricTracker(*METRICS_TRACKED, mode='train')
    w2 = torch.Tensor([1.0, 1.0, 1.5])

    if (args.cuda):
        model.cuda()
        w2 = w2.cuda()

    train_metrics.reset()
    # Sanity-check counters (debugging only)
    counter_batches = 0
    counter_covid = 0

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        counter_batches += 1

        if (args.cuda):
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss, counter = weighted_loss(output, target, w2)
        counter_covid += counter
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)
        precision_mean, recall_mean = precision_score(output, target)

        num_samples = batch_idx * args.batch_size + 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc,
                'precision_mean': precision_mean,
                'recall_mean': recall_mean
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(args, epoch, num_samples, trainloader, train_metrics)
    print("--- %s seconds ---" % (time.time() - start_time))
    print_summary(args, epoch, num_samples, train_metrics, mode="Training")
    return train_metrics
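
weighted_loss in examples #5 and #7 returns both a loss and a counter that example #5 accumulates into counter_covid. A minimal sketch under the assumption that it wraps class-weighted cross entropy and counts batch occurrences of the up-weighted class (index 2, weight 1.5); the tracked_class parameter is an invention for illustration:

import torch.nn.functional as F

def weighted_loss(output, target, weights, tracked_class=2):
    # Hypothetical: class-weighted cross entropy plus a per-batch count of one class
    loss = F.cross_entropy(output, target, weight=weights)
    counter = (target == tracked_class).sum().item()
    return loss, counter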
Example #6
def train(args, model, trainloader, optimizer, epoch, class_weight):
    model.train()

    metrics = Metrics('')
    metrics.reset()
    #-------------------------------------------------------
    # This block freezes the BatchNorm layers of the pretrained network
    #for m in model.modules():
    #    if isinstance(m, nn.BatchNorm2d):
    #        m.train()
    #        m.weight.requires_grad = False
    #        m.bias.requires_grad = False
    #-----------------------------------------------------

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        if (args.cuda):
            input_data = input_data.cuda()
            target = target.cuda()
        #print(input_data.shape)
        output = model(input_data)
        #print(output.shape)
        #print(target.shape)
        #loss = focal_loss(output, target)
        if args.model == 'CovidNet_DenseNet':
            output = output[-1]

        loss = crossentropy_loss(output, target, weight=class_weight)
        loss.backward()
        optimizer.step()
        correct, total, acc = accuracy(output, target)

        num_samples = batch_idx * args.batch_size + 1
        _, output_class = output.max(1)
        #print(output_class)
        #print(target)
        bacc = balanced_accuracy_score(target.cpu().detach().numpy(),
                                       output_class.cpu().detach().numpy())
        metrics.update({
            'correct': correct,
            'total': total,
            'loss': loss.item(),
            'accuracy': acc,
            'bacc': bacc
        })
        print_stats(args, epoch, num_samples, trainloader, metrics)

    print_summary(args, epoch, num_samples, metrics, mode="Training")
    return metrics
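
crossentropy_loss(output, target, weight=...) in examples #6 and #8 is presumably a thin functional wrapper; assuming so, it reduces to F.cross_entropy:

import torch.nn.functional as F

def crossentropy_loss(output, target, weight=None):
    # Assumed wrapper: weighted cross entropy with mean reduction
    return F.cross_entropy(output, target, weight=weight, reduction='mean')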
Example #7
def validation(args, model, testloader, epoch):

    model.eval()

    val_metrics = MetricTracker(*METRICS_TRACKED, mode='val')
    val_metrics.reset()
    w2 = torch.Tensor([1.0, 1.0, 1.5])  # w_full = torch.Tensor([1.456, 1.0, 15.71])

    if (args.cuda):
        w2 = w2.cuda()

    confusion_matrix = torch.zeros(args.classes, args.classes)

    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors

            if (args.cuda):
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            loss, counter = weighted_loss(output, target, w2)
            correct, total, acc = accuracy(output, target)
            precision_mean, recall_mean = precision_score(output, target)

            num_samples = batch_idx * args.batch_size + 1
            _, preds = torch.max(output, 1)

            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'total': total,
                    'loss': loss.item(),
                    'accuracy': acc,
                    'precision_mean': precision_mean,
                    'recall_mean': recall_mean
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    print_summary(args, epoch, num_samples, val_metrics, mode="Validation")
    print('Confusion Matrix\n {}'.format(confusion_matrix.cpu().numpy()))

    return val_metrics, confusion_matrix
Example #8
def validation(args, model, testloader, epoch, class_weight):
    model.eval()

    #-------------------------------------------------------
    # This block freezes the BatchNorm layers of the pretrained network
    #for m in model.modules():
    #    if isinstance(m, nn.BatchNorm2d):
    #        m.train()
    #        m.weight.requires_grad = False
    #        m.bias.requires_grad = False
    #-----------------------------------------------------

    metrics = Metrics('')
    metrics.reset()
    confusion_matrix = torch.zeros(args.classes, args.classes)
    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            if (args.cuda):
                input_data = input_data.cuda()
                target = target.cuda()
            #print(input_data.shape)
            output = model(input_data)
            if args.model == 'CovidNet_DenseNet':
                output = output[-1]
            #loss = focal_loss(output, target)
            loss = crossentropy_loss(output, target, weight=class_weight)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batch_size + 1
            _, preds = torch.max(output, 1)
            bacc = balanced_accuracy_score(target.cpu().detach().numpy(),
                                           preds.cpu().detach().numpy())
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            metrics.update({
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc,
                'bacc': bacc
            })
            #print_stats(args, epoch, num_samples, testloader, metrics)

    print_summary(args, epoch, num_samples, metrics, mode="Validation")
    return metrics, confusion_matrix
Example #9
def train(args, model, trainloader, optimizer, epoch, writer, log):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = [
        'loss', 'correct', 'total', 'accuracy', 'ppv', 'sensitivity'
    ]
    train_metrics = MetricTracker(*metric_ftns, writer=writer, mode='train')
    train_metrics.reset()
    confusion_matrix = torch.zeros(args.class_dict, args.class_dict)

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        if (args.cuda):
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)
        pred = torch.argmax(output, dim=1)

        num_samples = batch_idx * args.batch_size + 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(args, epoch, num_samples, trainloader, train_metrics)
        for t, p in zip(target.cpu().view(-1), pred.cpu().view(-1)):
            confusion_matrix[t.long(), p.long()] += 1
    s = sensitivity(confusion_matrix.numpy())
    ppv = positive_predictive_value(confusion_matrix.numpy())
    print(f" s {s} ,ppv {ppv}")
    # train_metrics.update('sensitivity', s, writer_step=(epoch - 1) * len(trainloader) + batch_idx)
    # train_metrics.update('ppv', ppv, writer_step=(epoch - 1) * len(trainloader) + batch_idx)
    print_summary(args, epoch, num_samples, train_metrics, mode="Training")
    return train_metrics
Example #10
def val(args, model, data_loader, epoch, writer, device):

    model.eval()

    criterion = nn.CrossEntropyLoss(reduction='mean')
    metric_ftns = ['loss', 'correct', 'total', 'accuracy', 'ppv', 'sens']
    metrics = MetricTracker(*metric_ftns, writer=writer, mode='val')
    metrics.reset()

    cm = torch.zeros(args.classes, args.classes)

    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(data_loader):
            torch.cuda.empty_cache()
            input_data, target = input_tensors[0].to(
                device), input_tensors[1].to(device)

            # Forward
            output = model(input_data)
            loss = criterion(output, target)

            correct, total, acc = accuracy(output, target)
            update_confusion_matrix(cm, output, target)

            # Update the metrics record
            metrics.update_all_metrics({
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            })

        ppv, sens = update_confusion_calc(cm)
        metrics.update_all_metrics({'sens': sens, 'ppv': ppv})

    return metrics, cm
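
update_confusion_matrix and update_confusion_calc (examples #3 and #10) are not defined here. Sketches consistent with how they are called, assuming the same row = true / column = predicted layout as the other examples:

import torch

def update_confusion_matrix(cm, output, target):
    # Accumulate counts in place: rows index the true class, columns the prediction
    _, pred = torch.max(output, dim=1)
    for t, p in zip(target.view(-1), pred.view(-1)):
        cm[t.long(), p.long()] += 1
    return cm

def update_confusion_calc(cm, eps=1e-12):
    # Macro-averaged precision (ppv) and recall (sensitivity) over classes
    tp = torch.diag(cm)
    ppv = (tp / (cm.sum(dim=0) + eps)).mean().item()
    sens = (tp / (cm.sum(dim=1) + eps)).mean().item()
    return ppv, sens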
Example #11
def validation(args, model, testloader, epoch):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')  # size_average is deprecated; use reduction

    metrics = Metrics('')
    metrics.reset()
    # The original omitted this initialization; args.classes is assumed, as in examples #6-#8
    confusion_matrix = torch.zeros(args.classes, args.classes)
    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            if (args.cuda):
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            #loss = criterion(output, target)
            loss = focal_loss(output, target)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batch_size + 1

            _, preds = torch.max(output, 1)
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            metrics.update({
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            })

            #print_stats(args, epoch, num_samples, testloader, metrics)

    print_summary(args, epoch, num_samples, metrics, mode="Validation")
    return metrics
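
focal_loss in example #11 presumably follows Lin et al.'s focal loss, FL(p_t) = -(1 - p_t)^gamma * log(p_t), which down-weights well-classified samples; a common multi-class sketch (assumed, with gamma left as a free parameter):

import torch.nn.functional as F

def focal_loss(output, target, gamma=2.0):
    # log p_t for each sample's true class
    logpt = F.log_softmax(output, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
    pt = logpt.exp()
    # Scale the cross-entropy term by (1 - p_t)^gamma
    return (-(1.0 - pt) ** gamma * logpt).mean()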
Example #12
    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation
        """

        batch_time = AverageMeter("batch_time")
        losses = AverageMeter("losses")
        losses_kws = AverageMeter("losses_kws")
        losses_dec = AverageMeter("losses_dec")
        losses_loc = AverageMeter("losses_loc")
        top1 = AverageMeter("top1")
        myprec = AverageMeter("myprec")
        myrec = AverageMeter("myrec")
        TarRank = []
        NonRank = []
        labels = []
        scores = []

        for k in range(0, self.num_words):
          TarRank.append([])
          NonRank.append([])

        self.model.eval()
        end = time.time()
        pbar = tqdm(total=len(self.val_dataloader))

        for i, lstVwidx in enumerate(self.val_dataloader):
          count = []
          positives = 0
          for k in range(0,len(lstVwidx)):
            for l in lstVwidx[k][1]:
              if l != -1:
                positives +=1
                if l not in count:
                  count.append(l)
          if len(count)>1:   
            input, lens, widx, target, localization_mask, localization_mask_boundaries = transform_batch(
                lstVwidx, self.val_dataset.get_word_mask(), self.num_words, self.config)
            labels = np.concatenate((labels,target), axis=0)
            targetInt = target.astype('int32')
            # `async` is a reserved word in Python 3.7+; use non_blocking instead
            target = torch.from_numpy(target).cuda(non_blocking=True)
            input = torch.from_numpy(input).float().cuda(non_blocking=True)
            localization_mask = torch.from_numpy(localization_mask).float().cuda(non_blocking=True)
            widx = torch.from_numpy(widx).cuda(non_blocking=True)
            input_var = Variable(input)
            target_var = Variable(target.view(-1,1)).float()
            grapheme = []
            phoneme = []
            p_lens = []
            for w in widx:
                p_lens.append(len(self.val_dataset.get_GP(w)[0]))
                grapheme.append(self.val_dataset.get_GP(w)[0])
                phoneme.append(self.val_dataset.get_GP(w)[1])
            p_lens = np.asarray(p_lens) 
            if self.g2p:
              graphemeTensor = Variable(self.val_dataset.grapheme2tensor_g2p(grapheme)).cuda()
              phonemeTensor = Variable(self.val_dataset.phoneme2tensor_g2p(phoneme)).cuda()
              preds = self.model(vis_feat_lens=lens, p_lengths=p_lens,
                                 phonemes=phonemeTensor[:-1].detach(),
                                 graphemes=graphemeTensor.detach(), vis_feats=input_var,
                                 use_BE_localiser=self.use_BE_localiser,
                                 epoch=epoch, config=self.config)
              tdec = phonemeTensor[1:]
            else:
              graphemeTensor = Variable(self.val_dataset.grapheme2tensor(grapheme)).cuda()
              phonemeTensor = Variable(self.val_dataset.phoneme2tensor(phoneme)).cuda()
              preds = self.model(vis_feat_lens=lens, p_lengths=p_lens,
                                 phonemes=phonemeTensor.detach(),
                                 graphemes=graphemeTensor[:-1].detach(), vis_feats=input_var,
                                 use_BE_localiser=self.use_BE_localiser,
                                 epoch=epoch, config=self.config)  # changed vis_feat_lens from lens to p_lens
              tdec = graphemeTensor[1:]
            scores = np.concatenate((scores, preds['keyword_prob'].view(1, len(target)).detach().cpu().numpy()[0]), axis=0) 
            loss_dec = module_loss.nll_loss(preds["odec"].view(preds["odec"].size(0)*preds["odec"].size(1),-1), tdec.view(tdec.size(0)*tdec.size(1)))
            loss_kws = self.BCE_loss(preds["max_logit"], target_var )
            if self.loc_weight_loss:
              localization_mask = localization_mask*-1000000
              o_logits = localization_mask + preds["o_logits"].squeeze(-1)
              max_localised = o_logits.max(1)[0]
              loss_loc = self.BCE_loss(max_localised.unsqueeze(1), target_var)
              loss_total = self.kws_weight_loss*loss_kws + self.dec_weight_loss*loss_dec + self.loc_weight_loss*loss_loc
            else:
              loss_total = self.kws_weight_loss*loss_kws + self.dec_weight_loss*loss_dec
              loss_loc = loss_total
            PTrue = preds["keyword_prob"]
            PFalseTrue = torch.cat((PTrue.add(-1).mul(-1),PTrue),1)
            prec1 = module_met.accuracy(PFalseTrue.data, target, topk=(1,))[0]
            PR = module_met.PrecRec(PFalseTrue.data, target, topk=(1,))
            losses.update(loss_total.item(), input.size(0))
            losses_kws.update(loss_kws.item(), input.size(0))
            losses_dec.update(loss_dec.item(), input.size(0))
            losses_loc.update(loss_loc.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            myprec.update(PR[0], (PFalseTrue.data[:,1]>0.5).sum())
            myrec.update(PR[1], target.sum())
          pbar.update(1)
        self.writer.set_step(epoch, 'valid')
        self.writer.add_scalar("loss_kws", losses_kws.avg)
        self.writer.add_scalar("loss_loc", losses_loc.avg)
        self.writer.add_scalar("loss_dec", losses_dec.avg)
        self.writer.add_scalar("acc", top1.avg)
        batch_time.update(time.time() - end)
        end = time.time()
        
        print("Prec@1 {top1.avg:.3f}, Precision {myprec.avg:.3f}, Recall {myrec.avg:.3f}, Loss_kws {loss_kws.avg:.4f}, Loss_dec {loss_dec.avg:.4f},Loss_loc {loss_loc.avg:.4f}".format(top1=top1, myprec=myprec, myrec=myrec, loss_kws=losses_kws, loss_dec=losses_dec, loss_loc=losses_loc))
        self.logger.info("Prec@1 {top1.avg:.3f}, Precision {myprec.avg:.3f}, Recall {myrec.avg:.3f}, Loss_kws {loss_kws.avg:.4f}, Loss_dec {loss_dec.avg:.4f}".format(top1=top1, myprec=myprec, myrec=myrec, loss_kws=losses_kws, loss_dec=losses_dec))
        pbar.close()
Example #13
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        """
        if epoch > 40: 
          self.len_epoch = len(self.train_dataloader)

        if epoch > 60 and self.config["arch"]["args"]["rnn2"]==True:
          assert self.dec_weight_loss == 1
          self.dec_weight_loss = 0.1

        if epoch >= self.start_BEloc_epoch:
          self.use_BE_localiser = True

        self.model.train()
        batch_time = AverageMeter("batch_time")
        data_time = AverageMeter("data_time")
        losses_kws = AverageMeter("losses_kws")
        losses_dec = AverageMeter("losses_dec")
        losses_loc = AverageMeter("losses_loc")
        top1 = AverageMeter("top1")
        end = time.time()
  
        pbar = tqdm(total=len(self.train_dataloader))
        for i, lstVwidx in enumerate(self.train_dataloader):
          count = []
          positives = 0
          for k in range(0,len(lstVwidx)):
            for l in lstVwidx[k][1]:
              if l != -1:
                positives +=1
                if l not in count:
                  count.append(l)
          if len(count)>1:   
            input, lens, widx, target, localization_mask, localization_mask_boundaries = transform_batch(
                lstVwidx, self.train_dataset.get_word_mask(), self.num_words, self.config)
            # `async` is a reserved word in Python 3.7+; use non_blocking instead
            target = torch.from_numpy(target).cuda(non_blocking=True)
            input = torch.from_numpy(input).float().cuda(non_blocking=True)
            localization_mask = torch.from_numpy(localization_mask).float().cuda(non_blocking=True)
            widx = torch.from_numpy(widx).cuda(non_blocking=True)
            grapheme = []
            phoneme = []
            p_lens = []
            for w in widx:
              p_lens.append(len(self.train_dataset.get_GP(w)[0]))
              grapheme.append(self.train_dataset.get_GP(w)[0])
              phoneme.append(self.train_dataset.get_GP(w)[1])
            input_var = Variable(input)
            p_lens = np.asarray(p_lens)
            target_var = Variable(target.view(-1,1)).float()
            if self.g2p:
              graphemeTensor = Variable(self.train_dataset.grapheme2tensor_g2p(grapheme)).cuda()
              phonemeTensor = Variable(self.train_dataset.phoneme2tensor_g2p(phoneme)).cuda()
              preds = self.model(vis_feat_lens=lens, p_lengths=p_lens,
                                 phonemes=phonemeTensor[:-1].detach(),
                                 graphemes=graphemeTensor.detach(), vis_feats=input_var,
                                 use_BE_localiser=self.use_BE_localiser,
                                 epoch=epoch, config=self.config)
              tdec = phonemeTensor[1:]
            else:
              graphemeTensor = Variable(self.train_dataset.grapheme2tensor(grapheme)).cuda()
              phonemeTensor = Variable(self.train_dataset.phoneme2tensor(phoneme)).cuda()
              preds = self.model(vis_feat_lens=lens, p_lengths=p_lens,
                                 phonemes=phonemeTensor.detach(),
                                 graphemes=graphemeTensor[:-1].detach(), vis_feats=input_var,
                                 use_BE_localiser=self.use_BE_localiser,
                                 epoch=epoch, config=self.config)  # changed vis_feat_lens from lens to p_lens
              tdec = graphemeTensor[1:]
            loss_dec = module_loss.nll_loss(preds["odec"].view(preds["odec"].size(0)*preds["odec"].size(1),-1),
              tdec.view(tdec.size(0)*tdec.size(1)))
            loss_kws = self.BCE_loss(preds["max_logit"], target_var)
            if self.loc_weight_loss:
              localization_mask = localization_mask*-1000000
              o_logits = localization_mask + preds["o_logits"].squeeze(-1)
              max_localised = o_logits.max(1)[0]
              loss_loc = self.BCE_loss(max_localised.unsqueeze(1), target_var)
              loss_total = self.kws_weight_loss*loss_kws + self.dec_weight_loss*loss_dec + self.loc_weight_loss*loss_loc
            else: 
              loss_total = self.kws_weight_loss*loss_kws+ self.dec_weight_loss*loss_dec
              loss_loc = loss_total
            PTrue = preds["keyword_prob"]
            PFalseTrue = torch.cat((PTrue.add(-1).mul(-1),PTrue),1)
            prec1 = module_met.accuracy(PFalseTrue.data, target, topk=(1,))[0]
            losses_kws.update(loss_kws.item(), input.size(0))
            losses_dec.update(loss_dec.item(), input.size(0))
            losses_loc.update(loss_loc.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            self.optimizer.zero_grad()
            loss_total.backward()
            # clip_grad_norm is deprecated, and norm_type must be numeric, not the string 'inf'
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip, norm_type=float('inf'))
            self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

  
          pbar.update(1)
        self.writer.set_step(epoch)
        self.writer.add_scalar("loss_kws", losses_kws.avg)
        self.writer.add_scalar("loss_loc", losses_loc.avg)
        self.writer.add_scalar("loss_dec", losses_dec.avg)
        self.writer.add_scalar("acc", top1.avg)
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t'
              'Loss_kws {loss_kws.val:.4f} ({loss_kws.avg:.4f})\t'
              'Loss_loc {loss_loc.val:.4f} ({loss_loc.avg:.4f})\t'
              'Loss_dec {loss_dec.val:.4f} ({loss_dec.avg:.4f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
              epoch, i, len(self.train_dataloader), batch_time=batch_time, data_time= data_time,
              loss_kws=losses_kws, loss_loc=losses_loc, loss_dec=losses_dec, top1=top1))
        self.logger.info('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f}) \t'
              'Loss_kws {loss_kws.val:.4f} ({loss_kws.avg:.4f})\t'
              'Loss_loc {loss_loc.val:.4f} ({loss_loc.avg:.4f})\t'
              'Loss_dec {loss_dec.val:.4f} ({loss_dec.avg:.4f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
              epoch, i, len(self.train_dataloader), batch_time=batch_time, data_time= data_time,
              loss_kws=losses_kws, loss_loc=losses_loc, loss_dec=losses_dec, top1=top1))

        pbar.close()
         
        if self.do_validation:
          self._valid_epoch(epoch) 

        if self.lr_scheduler is not None:
          self.lr_scheduler.step()
Example #14
def main(args):
    config = parse_config(args.field, args.config)

    # exp params
    seed = int(config['EXP']['seed'])
    exp_name = config['EXP']['exp_name']
    batch_size = int(config['EXP']['batch_size'])
    model_name = config['EXP']['model']
    epochs = int(config['EXP']['epochs'])
    lr = float(config['EXP']['lr'])
    val_freq = int(config['EXP']['val_freq'])
    worker = int(config['EXP']['worker'])
    gpus = config['EXP']['gpus']
    unsuper = config['EXP']['unsuper']

    # dataset params
    dataset = config['DATASET']['dataset']
    root = config['DATASET']['root']
    num_imgs_per_cat = config['DATASET']['num_imgs_per_cat']
    d_type = config['DATASET']['type']

    # gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus

    # model params
    optim_name = config['MODEL']['optim']
    scheduler_name = config['MODEL']['scheduler']
    criterion_name = config['MODEL']['criterion']
    transfer = config['MODEL']['transfer']
    block_num = config['MODEL']['block_op']
    no_head = config['MODEL']['no_head']
    cutmix_alpha = float(config['MODEL']['cutmix_alpha'])
    cutmix_prob = float(config['MODEL']['cutmix_prob'])
    labelsmooth = config['MODEL']['labelsmooth']

    # fix seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    unsuper = (unsuper == 'true')
    # make dataloader
    if unsuper:
        rot_preprocess = rotpreprocess()
    train_preprocess = trainpreprocess(config['DATASET'])
    val_preprocess = valpreprocess()
    if dataset == 'cifar10':
        trainset = CIFAR10(root=root,
                           train=True,
                           download=True,
                           transform=train_preprocess,
                           unsuper=unsuper)
        valset = CIFAR10(root=root,
                         train=False,
                         download=True,
                         transform=val_preprocess,
                         unsuper=unsuper)
        num_classes = len(trainset.classes)
    elif dataset == 'fasion':
        if unsuper:
            trainset = SimpleImageLoader(root=root,
                                         split='unlabel',
                                         transform=rot_preprocess,
                                         unsuper=unsuper)
        else:
            trainset = SimpleImageLoader(root=root,
                                         split='train',
                                         transform=train_preprocess,
                                         num_imgs_per_cat=num_imgs_per_cat)
        valset = SimpleImageLoader(root=root,
                                   split='validation',
                                   transform=val_preprocess,
                                   unsuper=unsuper)
        num_classes = trainset.classnumber
    else:
        raise ValueError('unsupported dataset: {}'.format(dataset))
    '''if unsuper:
        batch_size = batch_size//4'''
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=worker)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=worker)

    # get model
    if model_name == 'efficientnet':
        phi = int(config['MODEL']['depth'])
        print(transfer, block_num, num_classes, num_imgs_per_cat)
        if no_head == 'true':
            model = efficientnet(phi=phi,
                                 num_classes=num_classes,
                                 transfer=transfer,
                                 block_num=block_num,
                                 no_head=True)
        else:
            model = efficientnet(phi=phi,
                                 num_classes=num_classes,
                                 transfer=transfer,
                                 block_num=block_num)
    elif model_name == 'resnet':
        depth = int(config['MODEL']['depth'])
        model = resnet(depth=depth, num_classes=num_classes)
    else:
        raise ValueError('no supported model name')

    print(model)
    model = model.cuda()
    #model = torch.nn.DataParallel(model).cuda()

    # set loss & optimizer & scheduler
    if criterion_name == 'crossentropy':
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError('no supported loss function name')

    if optim_name == 'adam':
        #optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optim_name == 'rangerlars':
        #optimizer = RangerLars(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        optimizer = RangerLars(model.parameters(), lr=lr)
    else:
        raise ValueError('no supported optimizer name')

    if scheduler_name == 'reducelr':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=3,
                                                         verbose=True)
    elif scheduler_name == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=epochs,
                                                               eta_min=0.)
    elif scheduler_name == 'cyclic':
        scheduler = CyclicLR(optimizer,
                             base_lr=lr * 0.3,
                             max_lr=lr,
                             step_size_up=10,
                             cycle_momentum=False)
    else:
        raise ValueError('no supported scheduler name')

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model.train()  # setting model.training directly would skip submodules
    # save dir
    try:
        os.mkdir('saved/models/{}'.format(exp_name))
        os.mkdir('saved/logs/{}'.format(exp_name))
    except FileExistsError:
        raise ValueError('experiment name already exists: {}'.format(exp_name))

    writer = SummaryWriter("saved/logs/{}".format(exp_name))

    #save config
    with open('saved/models/{}/config.ini'.format(exp_name),
              'w') as configfile:
        config_saved = configparser.ConfigParser(allow_no_value=True)
        config_saved.read(args.config)
        config_saved.write(configfile)

    # training
    global_step = 0  # renamed from "iter", which shadows the Python builtin
    best_acc = 0
    for epoch_num in range(epochs):

        # -------------------------------- train model ---------------------------- #
        model.train()
        epoch_loss = []
        for iter_num, data in enumerate(trainloader):
            #break
            optimizer.zero_grad()
            # In unsupervised (rotation) mode an earlier version shuffled the
            # four rotated copies of each image:
            '''random_index = [0,2,4,6]
            random.shuffle(random_index)
            image = torch.cat([data[random_index[0]], data[random_index[1]], data[random_index[2]], data[random_index[3]]], dim=0)
            label = torch.cat([data[random_index[0]+1], data[random_index[1]+1], data[random_index[2]+1], data[random_index[3]+1]], dim=0)
            image = image.cuda()
            label = label.cuda()'''
            # Image/label unpacking is identical in supervised and unsupervised modes
            image, label = data[0], data[1]
            image = image.cuda()
            label = label.cuda()
            pred = model(image)
            #print(pred, label)
            loss = criterion(pred, label)

            #print(label)
            '''if cutmix_alpha <= 0.0 or np.random.rand(1) > cutmix_prob:
                pred = model(image)
                loss = criterion(pred, label)
            else:
                # CutMix : generate mixed sample
                lam = np.random.beta(cutmix_alpha, cutmix_alpha)
                rand_index = torch.randperm(image.size()[0]).cuda()
                target_a = label
                target_b = label[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(image.size(), lam)
                image[:, :, bbx1:bbx2, bby1:bby2] = image[rand_index, :, bbx1:bbx2, bby1:bby2]

                pred = model(image)
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image.size()[-1] * image.size()[-2]))
                loss = criterion(pred, target_a) * lam + criterion(pred, target_b) * (1. - lam)'''

            if bool(loss == 0):
                continue

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            #loss.backward()
            #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()

            epoch_loss.append(float(loss))

            writer.add_scalar('Loss/train', loss, global_step)
            global_step += 1
            print('Epoch: {} | Iteration: {} | Running loss: {:1.5f}'.format(
                epoch_num, iter_num, np.mean(np.array(epoch_loss))))

        # -------------------------------- validate model ---------------------------- #
        model.eval()
        val_loss = []
        top1_list = []
        top5_list = []
        with torch.no_grad():
            for iter_num, data in enumerate(valloader):

                # In unsupervised mode an earlier version concatenated the rotated copies:
                '''image = torch.cat([data[0], data[2], data[4], data[6]], dim=0)
                label = torch.cat([data[1], data[3], data[5], data[7]], dim=0)
                image = image.cuda()
                label = label.cuda()'''
                # Image/label unpacking is identical in both modes
                image, label = data[0], data[1]
                image = image.cuda()
                label = label.cuda()
                #print(image, label)
                pred = model(image)
                loss = criterion(pred, label)

                if unsuper:
                    top1, _ = accuracy(pred, label, (1, 2))
                else:
                    top1, top5 = accuracy(pred, label, (1, 5))

                top1_list.append(top1.item())
                print('Val_Epoch: {} | iter : {} | top1-acc : {:1.5f}'.format(
                    epoch_num, iter_num, top1))

                if not unsuper:
                    top5_list.append(top5.item())
                #val_correct += (torch.max(output, 1)[1] == label).sum().item()
                val_loss.append(float(loss))

        #acc = val_correct / len(valset)
        val_loss_mean = np.mean(np.array(val_loss))
        top1 = np.mean(np.array(top1_list))
        if not unsuper:
            top5 = np.mean(np.array(top5_list))

        if scheduler_name == 'reducelr':
            scheduler.step(val_loss_mean)
        else:
            scheduler.step()

        if unsuper:
            print('Epoch: {} | loss : {:1.5f} | top1-acc : {:1.5f} '.format(
                epoch_num, val_loss_mean, top1))
        else:
            print(
                'Epoch: {} | loss : {:1.5f} | top1-acc : {:1.5f} | top5-acc : {:1.5f}'
                .format(epoch_num, val_loss_mean, top1, top5))

        if best_acc < top1:
            best_acc = top1
            # model is not wrapped in DataParallel (the wrapping line is commented
            # out above), so save model.state_dict(), not model.module.state_dict();
            # save the best checkpoint only when the accuracy actually improves
            torch.save(
                model.state_dict(),
                'saved/models/{}/model_best_acc_{}.pt'.format(exp_name, best_acc))

        writer.add_scalar('Acc/top1-acc', top1, epoch_num)
        if not unsuper:
            writer.add_scalar('Acc/top5-acc', top5, epoch_num)
        writer.add_scalar('Loss/val', val_loss_mean, epoch_num)
        writer.add_scalar('status/lr', optimizer.param_groups[0]['lr'],
                          epoch_num)
        torch.save(model.state_dict(),
                   'saved/models/{}/model_{}.pt'.format(exp_name, epoch_num))