Example #1
    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch+1)<=fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            
            self.optimizer.zero_grad()
            outputs = self.model(imgs)
            loss = self._compute_loss(self.criterion, outputs, pids)
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses.update(loss.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx+1) % print_freq == 0:
                # estimate remaining time
                num_batches = len(trainloader)
                eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'Eta {eta}'.format(
                      epoch+1, max_epoch, batch_idx+1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accs,
                      lr=self.optimizer.param_groups[0]['lr'],
                      eta=eta_str
                    )
                )
            
            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
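
All of these examples report losses, accuracies and timings through an AverageMeter and print its .val / .avg attributes. The class itself does not appear on this page; the following minimal sketch (an assumption about its typical shape, not the library's own code) shows what such a meter does.

class AverageMeter:
    """Keeps the most recent value and a running, sample-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # n is the number of samples behind `val`, e.g. the batch size
        # when updating a per-batch loss.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count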
Example #2
def train(epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()

    if (epoch + 1) <= args.fixbase_epoch and args.open_layers is not None:
        print('* Only train {} (epoch: {}/{})'.format(args.open_layers,
                                                      epoch + 1,
                                                      args.fixbase_epoch))
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, data in enumerate(trainloader):
        data_time.update(time.time() - end)

        imgs, attrs = data[0], data[1]
        if use_gpu:
            imgs = imgs.cuda()
            attrs = attrs.cuda()

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, attrs)
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), imgs.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            # estimate remaining time
            num_batches = len(trainloader)
            eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                            (args.max_epoch -
                                             (epoch + 1)) * num_batches)
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Lr {lr:.6f}\t'
                  'Eta {eta}'.format(epoch + 1,
                                     args.max_epoch,
                                     batch_idx + 1,
                                     len(trainloader),
                                     batch_time=batch_time,
                                     data_time=data_time,
                                     loss=losses,
                                     lr=optimizer.param_groups[0]['lr'],
                                     eta=eta_str))

        end = time.time()

    scheduler.step()
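
The remaining-time estimate in the print blocks is simply the average batch duration multiplied by the number of batches still to be processed, in the current epoch and in all remaining epochs. The same arithmetic as a standalone sketch (the helper name is hypothetical):

import datetime

def estimate_eta(batch_time_avg, batch_idx, num_batches, epoch, max_epoch):
    # batches left in this epoch plus all batches of the epochs still to run
    remaining = (num_batches - (batch_idx + 1)
                 + (max_epoch - (epoch + 1)) * num_batches)
    eta_seconds = batch_time_avg * remaining
    return str(datetime.timedelta(seconds=int(eta_seconds)))

# Example: 0.5 s per batch, 100 of 500 batches done in epoch 1 of 60
# estimate_eta(0.5, 99, 500, 0, 60) -> '4:09:10'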
Example #3
    def train(self, epoch, trainloader, fixbase=False, open_layers=None, print_freq=10):
        """Trains the model for one epoch on source datasets using softmax loss.

        Args:
            epoch (int): current epoch.
            trainloader (Dataloader): training dataloader.
            fixbase (bool, optional): whether to fix base layers. Default is False.
            open_layers (str or list, optional): layers open for training.
            print_freq (int, optional): print frequency. Default is 10.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()

        if fixbase and (open_layers is not None):
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            
            self.optimizer.zero_grad()
            outputs = self.model(imgs)
            loss = self._compute_loss(self.criterion, outputs, pids)
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses.update(loss.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx+1) % print_freq==0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accs))
            
            end = time.time()

        if (self.scheduler is not None) and (not fixbase):
            self.scheduler.step()
Example #4
    def two_stepped_transfer_learning(self, epoch, fixbase_epoch, open_layers):
        """Two-stepped transfer learning.

        The idea is to freeze base layers for a certain number of epochs
        and then open all layers for training.

        Reference: https://arxiv.org/abs/1611.05244
        """

        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))

            for model in self.models.values():
                open_specified_layers(model, open_layers, strict=False)
        else:
            for model in self.models.values():
                open_all_layers(model)
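
Both the per-epoch branch at the top of train() and two_stepped_transfer_learning() delegate the actual freezing to open_specified_layers and open_all_layers, which are not reproduced here. The sketch below shows what such helpers typically do (toggle requires_grad and train/eval mode per top-level child); it is an assumption, not the library's actual implementation.

def open_all_layers(model):
    # Make every parameter trainable again.
    model.train()
    for p in model.parameters():
        p.requires_grad = True

def open_specified_layers(model, open_layers):
    # Train only the named top-level children and freeze the rest.
    if isinstance(open_layers, str):
        open_layers = [open_layers]
    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False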
Example #5
    def _two_stepped_transfer_learning(self,
                                       epoch,
                                       fixbase_epoch,
                                       open_layers,
                                       model=None):
        """Two stepped transfer learning.

        The idea is to freeze base layers for a certain number of epochs
        and then open all layers for training.

        Reference: https://arxiv.org/abs/1611.05244
        """
        model1 = self.model1
        model2 = self.model2

        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(model1, open_layers)
            open_specified_layers(model2, open_layers)
        else:
            open_all_layers(model1)
            open_all_layers(model2)
Example #6
    def _unfreeze_aux_models(self):
        for model_name in self.model_names_to_freeze:
            model = self.models[model_name]
            model.train()
            open_all_layers(model)
Example #7
    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
        losses1 = AverageMeter()
        losses2 = AverageMeter()
        accs1 = AverageMeter()
        accs2 = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch+1)<=fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(trainloader)
        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            
            self.optimizer.zero_grad()
            output1, output2 = self.model(imgs)
            
            b = imgs.size(0)
            loss1 = self._compute_loss(self.criterion, output1, pids[:b//2])
            loss2 = self._compute_loss(self.criterion, output2, pids[b//2:b])
            loss = (loss1 + loss2) * 0.5
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses1.update(loss1.item(), pids[:b//2].size(0))
            losses2.update(loss2.item(), pids[b//2:b].size(0))
            accs1.update(metrics.accuracy(output1, pids[:b//2])[0].item())
            accs2.update(metrics.accuracy(output2, pids[b//2:b])[0].item())


            if (batch_idx+1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss1 {loss1.val:.4f} ({loss1.avg:.4f})\t'
                      'Loss2 {loss2.val:.4f} ({loss2.avg:.4f})\t'
                      'Acc1 {acc1.val:.2f} ({acc1.avg:.2f})\t'
                      'Acc2 {acc2.val:.2f} ({acc2.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.format(
                      epoch+1, max_epoch, batch_idx+1, num_batches,
                      batch_time=batch_time,
                      data_time=data_time,
                      loss1=losses1,
                      loss2=losses2,
                      acc1=accs1,
                      acc2=accs2,
                      lr=self.optimizer.param_groups[0]['lr'],
                      eta=eta_str
                    )
                )

            if self.writer is not None:
                n_iter = epoch * num_batches + batch_idx
                self.writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                self.writer.add_scalar('Train/Data', data_time.avg, n_iter)
                self.writer.add_scalar('Train/Loss1', losses1.avg, n_iter)
                self.writer.add_scalar('Train/Loss2', losses2.avg, n_iter)
                self.writer.add_scalar('Train/Acc1', accs1.avg, n_iter)
                self.writer.add_scalar('Train/Acc2', accs2.avg, n_iter)
                self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter)
            
            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
Example #8
    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch+1)<=fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)
            num_batches = len(trainloader)
            global_step = num_batches * epoch + batch_idx

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            
            self.optimizer.zero_grad()
            outputs, features = self.model(imgs)
            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = self.weight_t * loss_t + self.weight_x * loss_x
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))
            losses.update(loss.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids, topk=(1,))[0].item())

            # write to Tensorboard & comet.ml

            #self.writer.add_scalars('optim/accs',accs.val,global_step)
            self.experiment.log_metric('optim/accs',accs.val,step=global_step)

            #self.writer.add_scalar('optim/loss',losses.val,global_step) # loss, loss.item() or losses.val ??
            self.experiment.log_metric('optim/loss',losses.val,step=global_step) 
            #self.writer.add_scalar('optim/loss_triplet',losses_t.val,global_step) 
            self.experiment.log_metric('optim/loss_triplet',losses_t.val,step=global_step)
            #self.writer.add_scalar('optim/loss_softmax',losses_x.val,global_step) 
            self.experiment.log_metric('optim/loss_softmax',losses_x.val,step=global_step)

            #self.writer.add_scalar('optim/lr',self.optimizer.param_groups[0]['lr'],global_step)
            self.experiment.log_metric('optim/lr',self.optimizer.param_groups[0]['lr'],step=global_step)

            if (batch_idx+1) % print_freq == 0:
                # estimate remaining time
                num_batches = len(trainloader)
                eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                      'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'Eta {eta}'.format(
                      epoch+1, max_epoch, batch_idx+1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      loss_t=losses_t,
                      loss_x=losses_x,
                      acc=accs,
                      lr=self.optimizer.param_groups[0]['lr'],
                      eta=eta_str
                    )
                )
                self.writer.add_scalar('eta',eta_seconds,global_step)
                self.experiment.log_metric('eta',eta_seconds,step=global_step)
            
            end = time.time()

        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(losses.val)
        elif self.scheduler is not None:
            self.scheduler.step()
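
Example #8 branches on the scheduler type because ReduceLROnPlateau must be stepped with the metric it monitors, whereas schedule-driven schedulers are stepped without arguments. A hypothetical setup illustrating the difference (two schedulers on one optimizer only for the sake of the example):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5)
step_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=20, gamma=0.1)

# At the end of each epoch:
#   plateau_scheduler.step(epoch_loss)   # needs the monitored value
#   step_scheduler.step()                # purely schedule-driven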
Example #9
    def train(self,
              epoch,
              max_epoch,
              trainloader,
              fixbase_epoch=0,
              open_layers=None,
              print_freq=10):
        losses = AverageMeter()
        base_losses = AverageMeter()
        my_losses = AverageMeter()
        density_losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(trainloader)
        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            output = self.model(imgs)
            output, v = output

            # attention_loss = 0
            # for i in range(iy.shape[1]):
            #     attention_loss_i = self._compute_loss(self.criterion, iy[:, i, :], pids)
            #     attention_loss += attention_loss_i
            #     if (batch_idx + 1) % print_freq == 0:
            #         print("test: ", i, attention_loss_i)
            #self.optimizer.zero_grad()
            #attention_loss.backward(retain_graph=True)
            #self.optimizer.step()

            #self.optimizer.zero_grad()
            #print(output.shape)
            #my = my.squeeze()
            #print(my.shape)
            #Plabels = torch.range(0, my.shape[1] - 1).long().cuda().repeat(imgs.shape[0], 1, 1, 1).reshape(-1)
            #Plabels = torch.range(0, my.shape[1] - 1).long().cuda().repeat(my.shape[0], 1, 1, 1).view(-1, my.shape[1])

            #print(Plabels.shape, Plabels)
            #my_loss = torch.nn.CrossEntropyLoss()(my, Plabels)
            self.optimizer.zero_grad()  # reset gradients accumulated from the previous iteration
            base_loss = self._compute_loss(self.criterion, output, pids)
            loss = base_loss  # + 0.1 * my_loss #- density

            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)
            losses.update(loss.item(), pids.size(0))
            base_losses.update(base_loss.item(), pids.size(0))
            #my_losses.update(my_loss.item(), pids.size(0))
            #density_losses.update(density.item(), pids.size(0))

            accs.update(metrics.accuracy(output, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.format(
                          epoch + 1,
                          max_epoch,
                          batch_idx + 1,
                          num_batches,
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str))

            if self.writer is not None:
                n_iter = epoch * num_batches + batch_idx
                self.writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                self.writer.add_scalar('Train/Data', data_time.avg, n_iter)
                self.writer.add_scalar('Train/Loss', losses.avg, n_iter)
                self.writer.add_scalar('Train/Base_Loss', base_losses.avg,
                                       n_iter)
                #self.writer.add_scalar('Train/My_Loss', my_losses.avg, n_iter)
                self.writer.add_scalar('Train/Density_loss',
                                       density_losses.avg, n_iter)
                self.writer.add_scalar('Train/Acc', accs.avg, n_iter)
                self.writer.add_scalar('Train/Lr',
                                       self.optimizer.param_groups[0]['lr'],
                                       n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
Example #10
    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
        losses = AverageMeter()
        reg_ow_loss = AverageMeter()
        metric_loss = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(trainloader)
        start_time = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - start_time)

            imgs, pids = self._parse_data_for_train(data)
            imgs, pids = self._apply_batch_transform(imgs, pids)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            if self.metric_loss is not None:
                embeddings, outputs = self.model(imgs, get_embeddings=True)
            else:
                outputs = self.model(imgs)

            loss = self._compute_loss(self.criterion, outputs, pids)

            if (epoch + 1) > fixbase_epoch:
                reg_loss = self.regularizer(self.model)
                reg_ow_loss.update(reg_loss.item(), pids.size(0))
                loss += reg_loss

            if self.metric_loss is not None:
                metric_val = self.metric_loss(F.normalize(embeddings, dim=1),
                                              outputs, pids)
                loss += metric_val
                metric_loss.update(metric_val.item(), pids.size(0))

            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())
            batch_time.update(time.time() - start_time)

            if print_freq > 0 and (batch_idx + 1) % print_freq == 0:
                eta_seconds = batch_time.avg * (num_batches-(batch_idx + 1) + (max_epoch - (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'AUX Losses {aux_losses.val:.4f} ({aux_losses.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.
                      format(
                          epoch + 1, max_epoch, batch_idx + 1, num_batches,
                          batch_time=batch_time,
                          data_time=data_time,
                          aux_losses=metric_loss,
                          loss=losses,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str,
                      )
                )

                if self.writer is not None:
                    n_iter = epoch * num_batches + batch_idx
                    self.writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                    self.writer.add_scalar('Train/Data', data_time.avg, n_iter)
                    info = self.criterion.get_last_info()
                    for k in info:
                        self.writer.add_scalar('AUX info/' + k, info[k], n_iter)
                    self.writer.add_scalar('Loss/train', losses.avg, n_iter)
                    if (epoch + 1) > fixbase_epoch:
                        self.writer.add_scalar('Loss/reg_ow', reg_ow_loss.avg, n_iter)
                    self.writer.add_scalar('Accuracy/train', accs.avg, n_iter)
                    self.writer.add_scalar('Learning rate', self.optimizer.param_groups[0]['lr'], n_iter)
                    if self.metric_loss is not None:
                        self.writer.add_scalar('Loss/local_push_loss',
                                               metric_val.item(), n_iter)
            start_time = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
Example #11
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None
    ):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        losses_recons = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
       
        open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            imgs_clean=imgs.clone()
            if self.use_gpu:
                imgs = imgs.cuda()
                imgs_clean = imgs_clean.cuda()
                pids = pids.cuda()
            labelss = []
            if 0 <= epoch < 15:
                randmt = RandomErasing(probability=0.5, sl=0.07, sh=0.3)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img)
                    labelss.append(p)

            if epoch >= 15:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.3)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img)
                    labelss.append(p)

            binary_labels = torch.tensor(np.asarray(labelss)).cuda()
            self.optimizer.zero_grad()
            
            outputs, outputs2, recons, bin_out1, bin_out2, bin_out3 = self.model(imgs)
            loss_mse = self.criterion_mse(recons, imgs_clean)
            loss = self.mgn_loss(outputs, pids)
            
            occ_loss1 = self.BCE_criterion(bin_out1.squeeze(1), binary_labels.float())
            occ_loss2 = self.BCE_criterion(bin_out2.squeeze(1), binary_labels.float())
            occ_loss3 = self.BCE_criterion(bin_out3.squeeze(1), binary_labels.float())

            loss = loss + 0.05 * loss_mse + 0.1 * occ_loss1 + 0.1 * occ_loss2 + 0.1 * occ_loss3
            #loss = self.weight_t * loss_t + self.weight_x * loss_x #+ #self.weight_r*loss_mse
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            #losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss.item(), pids.size(0))
            losses_recons.update(occ_loss1.item(), binary_labels.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    #'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                    'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                    'Loss_Occlusion {loss_r.val:.4f} ({loss_r.avg:.4f})\t'             
                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                    'Lr {lr:.6f}\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        #loss_t=losses_t,
                        loss_x=losses_x,
                        loss_r = losses_recons,
                        acc=accs,
                        lr=self.optimizer.param_groups[0]['lr'],
                        eta=eta_str
                    )
                )
            writer = None  # overrides the argument, so the writer block below never runs
            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
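
The transform above is called as imgs[i], p = randmt(img), i.e. it returns both the (possibly) erased image and a flag that is later used as a binary occlusion label. The standard RandomErasing does not return such a flag, so the sketch below is a hypothetical variant matching that call pattern, with sl/sh bounding the erased fraction of the image area.

import math
import random

class RandomErasing:
    """Hypothetical sketch: erase a random rectangle with probability
    `probability` and return (image, flag), flag being 1 if erasing happened."""

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
                 mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.sl = sl      # minimum erased area fraction
        self.sh = sh      # maximum erased area fraction
        self.r1 = r1      # minimum aspect ratio of the erased box
        self.mean = mean  # per-channel fill value

    def __call__(self, img):
        # img: CHW float tensor
        if random.uniform(0, 1) > self.probability:
            return img, 0
        for _ in range(100):
            area = img.size(1) * img.size(2)
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1.0 / self.r1)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if h < img.size(1) and w < img.size(2):
                x1 = random.randint(0, img.size(1) - h)
                y1 = random.randint(0, img.size(2) - w)
                for c in range(img.size(0)):
                    img[c, x1:x1 + h, y1:y1 + w] = self.mean[c % len(self.mean)]
                return img, 1
        return img, 0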
Example #12
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()

# -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            imgs_t, pids_t = self._parse_data_for_train(data_t)
            if self.use_gpu:
                imgs_t = imgs_t.cuda()

            self.optimizer.zero_grad()

            outputs, features = self.model(imgs)
            outputs_t, features_t = self.model(imgs_t)

            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = loss_t + loss_x

            if epoch > 20:
                loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(self.criterion_mmd, features, features_t)
                #loss = loss_t + loss_x + loss_mmd_bc + loss_mmd_wc
                loss = loss_t + loss_x + loss_mmd_global + loss_mmd_bc + loss_mmd_wc

                if False:
                    loss_t = torch.tensor(0)
                    loss_x = torch.tensor(0)
                    #loss = loss_mmd_bc + loss_mmd_wc
                    loss = loss_mmd_bc + loss_mmd_wc + loss_mmd_global


            loss.backward()
            self.optimizer.step()
# -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x.item(), pids.size(0))
            if epoch > 24:
                losses_mmd_bc.update(loss_mmd_bc.item(), pids.size(0))
                losses_mmd_wc.update(loss_mmd_wc.item(), pids.size(0))
                losses_mmd_global.update(loss_mmd_global.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                    'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                    'Loss_mmd_wc {losses3.val:.4f} ({losses3.avg:.4f})\t'
                    'Loss_mmd_bc {losses4.val:.4f} ({losses4.avg:.4f})\t'
                    'Loss_mmd_global {losses5.val:.4f} ({losses5.avg:.4f})\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        losses1=losses_triplet,
                        losses2=losses_softmax,
                        losses3=losses_mmd_wc,
                        losses4=losses_mmd_bc,
                        losses5=losses_mmd_global,
                        eta=eta_str
                    )
                )

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg, n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_bc', losses_mmd_bc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_wc', losses_mmd_wc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_global', losses_mmd_global.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()

        print_distri = False

        if print_distri:

            instances = self.datamanager.train_loader.sampler.num_instances
            batch_size = self.datamanager.train_loader.batch_size
            feature_size = 2048 # features_t.shape[1]  # 2048
            t = torch.reshape(features_t, (int(batch_size / instances), instances, feature_size))

            #  and compute bc/wc euclidean distance
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(features, (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001]
            w_c = [x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of apparition')
            plt.title('Source Domain')
            plt.legend()
            plt.show()

            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of apparition')
            plt.title('Target Domain')
            plt.legend()
            plt.show()
Example #13
    def train(self,
              epoch,
              max_epoch,
              trainloader,
              fixbase_epoch=0,
              open_layers=None,
              print_freq=10):
        use_matching_loss = False
        if epoch >= self.reg_matching_score_epoch:
            use_matching_loss = True
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        att_losses = AverageMeter()
        part_losses = AverageMeter()
        matching_losses = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)
        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids, pose_heatmaps = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
                pose_heatmaps = pose_heatmaps.cuda()

            self.optimizer.zero_grad()
            outputs, attmaps, part_score, v_g = self.model(imgs, pose_heatmaps)
            #classification loss
            loss_class = self._compute_loss(self.criterion, outputs, pids)
            # using for weighting each part with visibility
            # loss_class = self._compute_loss(self.criterion, outputs, pids, part_score.detach())
            loss_matching, loss_partconstr = self.part_c_criterion(
                v_g, pids, part_score, use_matching_loss)
            # add matching loss
            loss = loss_class + loss_partconstr
            # visibility verification loss
            if use_matching_loss:
                loss = loss + loss_matching
                matching_losses.update(loss_matching.item(), pids.size(0))
            if self.use_att_loss:
                loss_att = self.att_criterion(attmaps)
                loss = loss + loss_att
                att_losses.update(loss_att.item(), pids.size(0))
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)
            losses.update(loss.item(), pids.size(0))
            part_losses.update(loss_partconstr.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                num_batches = len(trainloader)
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'part_Loss {loss_part.val:.4f} ({loss_part.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'Eta {eta}'.format(
                          epoch + 1,
                          max_epoch,
                          batch_idx + 1,
                          len(trainloader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          loss_part=part_losses,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str),
                      end='\t')
                if self.use_att_loss:
                    print(
                        'attLoss {attloss.val:.4f} ({attloss.avg:.4f})'.format(
                            attloss=att_losses),
                        end='\t')
                if use_matching_loss:
                    print(
                        'matchLoss {match_loss.val:.4f} ({match_loss.avg:.4f})'
                        .format(match_loss=matching_losses),
                        end='\t')
                print('\n')

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
Example #14
    def train(
        self,
        epoch,
        max_epoch,
        writer,
        print_freq=1,
        fixbase_epoch=0,
        open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_recons_s = AverageMeter()
        losses_recons_t = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        losses_local = AverageMeter()

        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        weight_r = self.weight_r
        # -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(
                zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            imgs_t, pids_t = self._parse_data_for_train(data_t)
            if self.use_gpu:
                imgs_t = imgs_t.cuda()

            self.optimizer.zero_grad()
            noisy_imgs = self.random(imgs)
            outputs, part_outs, features, recons, z, mean, var, local_feat = self.model(
                noisy_imgs)
            parts_loss = 0

            for i in range(len(part_outs)):
                out = part_outs[i]

                parts_loss += self._compute_loss(
                    self.criterion_x, out, pids)  #  self.criterion( out, pids)

            parts_loss = parts_loss / len(part_outs)
            #print("local feats")
            #print(local_feat.shape)
            #print("global feats ")
            #print(local_feat.reshape(local_feat.size(0),-1).t().shape)

            imgs_t = self.random2(imgs_t)
            outputs_t, parts_out_t, features_t, recons_t, z_t, mean_t, var_t, local_feat_t = self.model(
                imgs_t)

            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss_r1 = self.loss_vae(imgs, recons, mean, var)
            loss_r2 = self.loss_vae(imgs_t, recons_t, mean_t, var_t)

            dist_mat_s = self.get_local_correl(local_feat)
            dist_mat_t = self.get_local_correl(local_feat_t)

            dist_mat_s = dist_mat_s.detach()
            local_loss = self.criterion_mmd.mmd_rbf_noaccelerate(
                dist_mat_s, dist_mat_t)

            kl_loss = torch.tensor(0)
            #loss = loss_t + loss_x + weight_r*loss_r1 +  (weight_r*2)*loss_r2 + loss_mmd_global #+ 0.1*kl_loss
            loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(
                self.criterion_mmd, features, features_t)
            loss = loss_t + loss_x + weight_r * loss_r1 + 0 * loss_r2 + loss_mmd_wc + loss_mmd_bc + loss_mmd_global + parts_loss  #weight_r2 =0 is best
            if epoch > 10:

                #loss = loss_t + loss_x  + weight_r*loss_r1  + (weight_r)*loss_r2  +  loss_mmd_wc + loss_mmd_bc  + loss_mmd_global

                if False:
                    loss_mmd_bc = torch.tensor(0)
                    loss_mmd_global = torch.tensor(0)
                    loss_mmd_wc = torch.tensor(0)
                    kl_loss = torch.tensor(0)

                    #loss = loss_mmd_bc + loss_mmd_wc
                    loss = loss_t + loss_x + weight_r * loss_r1 + (
                        weight_r
                    ) * loss_r2 + loss_mmd_wc + loss_mmd_bc + loss_mmd_global

            loss.backward()
            self.optimizer.step()
            # -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x.item(), pids.size(0))
            losses_recons_s.update(loss_r1.item(), pids.size(0))
            losses_recons_t.update(loss_r2.item(), pids.size(0))

            losses_local.update(local_loss.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                      'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                      'Loss_reconsS {losses4.val:.4f} ({losses4.avg:.4f})\t'
                      'Loss_reconsT {losses5.val:.4f} ({losses5.avg:.4f})\t'
                      'Loss_local {losses6.val:.4f} ({losses6.avg:.4f})\t'
                      'eta {eta}'.format(epoch + 1,
                                         max_epoch,
                                         batch_idx + 1,
                                         num_batches,
                                         batch_time=batch_time,
                                         losses1=losses_triplet,
                                         losses2=losses_softmax,
                                         losses4=losses_recons_s,
                                         losses5=losses_recons_t,
                                         losses6=losses_local,
                                         eta=eta_str))

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg,
                                  n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg,
                                  n_iter)

                writer.add_scalar('Train/Loss_recons_s', losses_recons_s.avg,
                                  n_iter)
                writer.add_scalar('Train/Loss_recons_t', losses_recons_t.avg,
                                  n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()

        print_distri = False

        if print_distri:
            print("Printing distribution")
            instances = self.datamanager.train_loader.sampler.num_instances
            batch_size = self.datamanager.train_loader.batch_size
            feature_size = 1024  # features_t.shape[1]  # 2048
            #print("local feature size!!!")
            #print(local_feat_t.shape)
            local_feat_t = local_feat_t.reshape(local_feat_t.size(0), -1)
            t = torch.reshape(
                local_feat_t,
                (int(batch_size / instances), instances, feature_size))

            #  and compute bc/wc euclidean distance
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(
                local_feat,
                (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [
                x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001
            ]
            w_c = [
                x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001
            ]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc,
                         bins='auto',
                         fit=norm,
                         kde=False,
                         label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc,
                         bins='auto',
                         fit=norm,
                         kde=False,
                         label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of Occurance')
            plt.title('Source Domain')
            plt.legend()
            plt.savefig(
                "/export/livia/home/vision/mkiran/work/Person_Reid/Video_Person/Domain_Adapt/D-MMD/figs/Non_Occluded_distribution.png"
            )
            plt.clf()

            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc,
                         bins='auto',
                         fit=norm,
                         kde=False,
                         label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc,
                         bins='auto',
                         fit=norm,
                         kde=False,
                         label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of apparition')
            plt.title('Non-Occluded Data Domain')
            plt.legend()
            plt.savefig(
                "/export/livia/home/vision/mkiran/work/Person_Reid/Video_Person/Domain_Adapt/D-MMD/figs/Occluded_distribution.png"
            )
            plt.clf()
Example #15
    def train(self,
              epoch,
              max_epoch,
              trainloader,
              fixbase_epoch=0,
              open_layers=None,
              print_freq=10):
        losses = AverageMeter()
        top_meters = [AverageMeter() for _ in range(5)]
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)
            num_batches = len(trainloader)
            global_step = num_batches * epoch + batch_idx

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            outputs = self.model(imgs)
            loss = self._compute_loss(self.criterion, outputs, pids)
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses.update(loss.item(), pids.size(0))
            accs = metrics.accuracy(outputs, pids, topk=(1, 2, 3, 4, 5))
            for i, meter in enumerate(top_meters):
                meter.update(accs[i].item())

            # write to Tensorboard & comet.ml
            accs_dict = {
                'train-accs-top-' + str(i + 1): float(r)
                for i, r in enumerate(accs)
            }

            for i, r in enumerate(accs):
                self.writer.add_scalars('optim/train-accs',
                                        {'top-' + str(i + 1): float(r)},
                                        global_step)
            self.experiment.log_metrics(accs_dict, step=global_step)

            self.writer.add_scalar(
                'optim/loss', losses.val,
                global_step)  # loss, loss.item() or losses.val ??
            # self.writer.add_scalar('optim/loss-avg',losses.avg,global_step)
            self.experiment.log_metric('optim/loss',
                                       losses.val,
                                       step=global_step)

            self.writer.add_scalar('optim/lr',
                                   self.optimizer.param_groups[0]['lr'],
                                   global_step)
            self.experiment.log_metric('optim/lr',
                                       self.optimizer.param_groups[0]['lr'],
                                       step=global_step)

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                num_batches = len(trainloader)
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-1 {r1.val:.2f} ({r1.avg:.2f})\t'
                      'Top-2 {r2.val:.2f} ({r2.avg:.2f})\t'
                      'Top-3 {r3.val:.2f} ({r3.avg:.2f})\t'
                      'Top-4 {r4.val:.2f} ({r4.avg:.2f})\t'
                      'Top-5 {r5.val:.2f} ({r5.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'Eta {eta}'.format(
                          epoch + 1,
                          max_epoch,
                          batch_idx + 1,
                          len(trainloader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          r1=top_meters[0],
                          r2=top_meters[1],
                          r3=top_meters[2],
                          r4=top_meters[3],
                          r5=top_meters[4],
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str))
                self.writer.add_scalar('eta', eta_seconds, global_step)
                self.experiment.log_metric('eta',
                                           eta_seconds,
                                           step=global_step)

            end = time.time()

        if isinstance(self.scheduler,
                      torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(losses.val)
        elif self.scheduler is not None:
            self.scheduler.step()
Example #16
0
    def train(self,
              epoch,
              max_epoch,
              writer,
              print_freq=10,
              fixbase_epoch=0,
              open_layers=None):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            outputs, features = self.model(imgs)
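            # criterion_t operates on the embedding features (metric loss),
            # criterion_x on the classification logits (softmax loss); the two
            # terms are combined with weight_t / weight_x below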
            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = self.weight_t * loss_t + self.weight_x * loss_x
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                      'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.format(
                          epoch + 1,
                          max_epoch,
                          batch_idx + 1,
                          num_batches,
                          batch_time=batch_time,
                          data_time=data_time,
                          loss_t=losses_t,
                          loss_x=losses_x,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str))

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar('Train/Lr',
                                  self.optimizer.param_groups[0]['lr'], n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
    def train(
        self,
        epoch,
        max_epoch,
        writer,
        fixbase_epoch=0,
        open_layers=None,
        print_freq=10
    ):
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            # softmax temperature
            if self.fixed_lmda or self.lmda_decay_step == -1:
                lmda = self.init_lmda
            else:
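                # decay lmda exponentially every lmda_decay_step epochs,
                # clamped below by min_lmda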
                lmda = self.init_lmda * self.lmda_decay_rate**(
                    epoch // self.lmda_decay_step
                )
                if lmda < self.min_lmda:
                    lmda = self.min_lmda

            for k in range(self.mc_iter):
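                # perform mc_iter gradient updates on the same mini-batch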
                outputs = self.model(imgs, lmda=lmda)
                loss = self._compute_loss(self.criterion, outputs, pids)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch_time.update(time.time() - end)

            losses.update(loss.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx+1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                    num_batches - (batch_idx+1) + (max_epoch -
                                                   (epoch+1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                    'Lr {lr:.6f}\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        acc=accs,
                        lr=self.optimizer.param_groups[0]['lr'],
                        eta=eta_str
                    )
                )

            if writer is not None:
                n_iter = epoch*num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss', losses.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
    def train(self,
              epoch,
              trainloader,
              fixbase=False,
              open_layers=None,
              print_freq=10):
        """Trains the model for one epoch on source datasets using softmax loss.

        Args:
            epoch (int): current epoch.
            trainloader (Dataloader): training dataloader.
            fixbase (bool, optional): whether to fix base layers. Default is False.
            open_layers (str or list, optional): layers open for training.
            print_freq (int, optional): print frequency. Default is 10.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        p_nums = AverageMeter()
        n_nums = AverageMeter()

        self.model.train()

        if fixbase and (open_layers is not None):
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        end = time.time()

        # sanity check (debug): each mini-batch groups several images per identity,
        # e.g. pids = [691, 691, 691, 691, 68, 68, 68, 68, 468, 468, 468, 468, ...]

        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            outputs = self.model(imgs)
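            # the criterion also returns p_num / n_num, logged below as the numbers
            # of positive / negative pairs ("P-num" / "N-num") per batch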
            loss, p_num, n_num = self._compute_loss(self.criterion, outputs,
                                                    pids)
            # loss = Variable(loss, requires_grad = True)
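            # skip the backward/update step when the loss is zero for this batch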
            if loss.item() > 0:
                loss.backward()
                self.optimizer.step()

            batch_time.update(time.time() - end)

            losses.update(loss.item(), pids.size(0))
            p_nums.update(p_num)
            n_nums.update(n_num)
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'MS Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'P-num {p.val:.2f} ({p.avg:.2f})\t'
                      'N-num {n.val:.2f} ({n.avg:.2f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                          epoch + 1,
                          batch_idx + 1,
                          len(trainloader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          p=p_nums,
                          n=n_nums,
                          acc=accs))

            end = time.time()

        if (self.scheduler is not None) and (not fixbase):
            self.scheduler.step()
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        losses_recons = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        self.mgn_targetPredict.train()

        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)
            open_all_layers(self.mgn_targetPredict)
            print('* Train all layers')

        num_batches = len(self.train_loader)
        end = time.time()

# -------------------------------------------------------------------------------------------------------------------- #
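        # iterate jointly over source (train_loader) and target (train_loader_t) batches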
        for batch_idx, (data, data_t) in enumerate(zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)

            # source-domain batch
            imgs, pids = self._parse_data_for_train(data)
            imgs_clean = imgs.clone().cuda()  # un-occluded copy, reconstruction target
            lam = 0
            # target-domain batch
            imgs_t, pids_t = self._parse_data_for_train(data_t)
            imagest_orig = imgs_t.cuda()  # un-occluded copy, reconstruction target
            labels = []   # per-image binary flag: was random erasing applied? (target)
            labelss = []  # per-image binary flag: was random erasing applied? (source)
            random_indexS = np.random.randint(0, imgs.size()[0])
            random_indexT = np.random.randint(0, imgs_t.size()[0])
            # progressively stronger random erasing on the source images
            if epoch > 10 and epoch < 35:
                randmt = RandomErasing(probability=0.5, sl=0.07, sh=0.22)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img, imgs[random_indexS])
                    labelss.append(p)

            if epoch >= 35:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.25)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img, imgs[random_indexS])
                    labelss.append(p)

            # progressively stronger random erasing on the target images
            if epoch > 10 and epoch < 35:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.2)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            if epoch >= 35 and epoch < 75:
                randmt = RandomErasing(probability=0.5, sl=0.2, sh=0.3)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            if epoch >= 75:
                randmt = RandomErasing(probability=0.5, sl=0.2, sh=0.35)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            # binary occlusion labels: targets for the BCE occlusion-prediction heads
            binary_labels = torch.tensor(np.asarray(labels)).cuda()
            binary_labelss = torch.tensor(np.asarray(labelss)).cuda()

            # occluded target-domain batch (imagest_orig keeps the un-occluded copy)
            imgs_transformed = imgs_t.cuda()

            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            self.optimizer.zero_grad()

            # forward pass on the (possibly occluded) source images; imgs_clean,
            # cloned before random erasing, serves as the reconstruction target
            outputs, output2, recons, bcc1, bocc2, bocc3 = self.model(imgs)

            # occlusion-prediction (BCE) losses on the source images; random erasing
            # is only applied after epoch 10, so the binary labels exist only then
            if epoch > 10:
                occ_losss1 = self.BCE_criterion(bcc1.squeeze(1), binary_labelss.float())
                occ_losss2 = self.BCE_criterion(bocc2.squeeze(1), binary_labelss.float())
                occ_losss3 = self.BCE_criterion(bocc3.squeeze(1), binary_labelss.float())
                occ_s = occ_losss1 + occ_losss2 + occ_losss3

            # ---------------- CutMix (disabled) ----------------
            """bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam)
            rand_index = torch.randperm(imgs.size()[0]).cuda()
            imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
            targeta = pids
            targetb = pids[rand_index]"""
            # ----------------------------------------------------

            # forward pass on the occluded target-domain images; imagest_orig
            # (the un-occluded copy) is used below as the reconstruction target
            outputs_t, output2_t, recons_t, bocct1, bocct2, bocct3 = self.model(imgs_transformed)
            outputs_t = self.mgn_targetPredict(output2_t)


            # reconstruction (MSE) losses: reconstruct the clean images
            loss_reconst = self.criterion_mse(recons_t, imagest_orig)
            loss_recons = self.criterion_mse(recons, imgs_clean)

            # occlusion-prediction (BCE) losses on the target images (labels exist
            # only once random erasing is active, i.e. after epoch 10)
            if epoch > 10:
                occ_loss1 = self.BCE_criterion(bocct1.squeeze(1), binary_labels.float())
                occ_loss2 = self.BCE_criterion(bocct2.squeeze(1), binary_labels.float())
                occ_loss3 = self.BCE_criterion(bocct3.squeeze(1), binary_labels.float())
                occ_t = occ_loss1 + occ_loss2 + occ_loss3

            pids_t = pids_t.cuda()
            loss_x = self.mgn_loss(outputs, pids)
            loss_x_t = self.mgn_loss(outputs_t, pids_t)
            # loss_x_t = self._compute_loss(self.criterion_x, y, targeta)  # *lam + self._compute_loss(self.criterion_x, y, targetb)*(1-lam)
            # loss_t_t = self._compute_loss(self.criterion_t, features_t, targeta)*lam + self._compute_loss(self.criterion_t, features_t, targetb)*(1-lam)

            if epoch > 10:
                # domain-alignment losses: within-class, between-class and global MMD
                # between source and target features
                loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(
                    self.criterion_mmd, outputs[0], outputs_t[0])
                # (commented-out variants also aligned outputs[2]/outputs_t[2]
                #  and outputs[3]/outputs_t[3] across domains)

                # joint supervised term: source/target classification + reconstruction
                l_joint = 1.5 * loss_x_t + loss_x + loss_reconst + loss_recons
                # domain term
                l_d = 0.5 * loss_mmd_bc + 0.8 * loss_mmd_wc + loss_mmd_global
                loss = 0.3 * l_d + 0.7 * l_joint + 0.2 * occ_t + 0.1 * occ_s
            else:
                # warm-up epochs: no random erasing or MMD alignment yet, so train
                # with the supervised and reconstruction terms only (assumed)
                loss = 1.5 * loss_x_t + loss_x + loss_reconst + loss_recons

            loss.backward()
            self.optimizer.step()
# -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            #losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x_t.item(), pids.size(0))
            #losses_recons.update(loss_recons.item(), pids.size(0))
            if epoch > 10:
                losses_mmd_bc.update(loss_mmd_bc.item(), pids.size(0))
                losses_mmd_wc.update(loss_mmd_wc.item(), pids.size(0))
                losses_mmd_global.update(loss_mmd_global.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    #'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                    'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                    'Loss_mmd_wc {losses3.val:.4f} ({losses3.avg:.4f})\t'
                    'Loss_mmd_bc {losses4.val:.4f} ({losses4.avg:.4f})\t'
                    'Loss_mmd_global {losses5.val:.4f} ({losses5.avg:.4f})\t'
                    #'Loss_recons {losses6.val:.4f} ({losses6.avg:.4f})\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        #losses1=losses_triplet,
                        losses2=losses_softmax,
                        losses3=losses_mmd_wc,
                        losses4=losses_mmd_bc,
                        losses5=losses_mmd_global,
                        #losses6 = losses_recons,
                        eta=eta_str
                    )
                )
            writer = None  # TensorBoard logging disabled below
            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg, n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_bc', losses_mmd_bc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_wc', losses_mmd_wc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_global', losses_mmd_global.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
        print_distri = True

        if print_distri:
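            # plot histograms of within-/between-class feature distances for the
            # source and target domains using the last batch's features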

            instances = self.datamanager.test_loader.query_loader.num_instances
            batch_size = self.datamanager.test_loader.batch_size
            feature_size = outputs[0].size(1) # features_t.shape[1]  # 2048
            features_t = outputs_t[0]
            features = outputs[0]
            # reshape to (num_identities, instances, feature_size) and compute
            # pairwise Euclidean distances; bct/bcs hold same-identity (within-class)
            # distances, wct/wcs hold cross-identity (between-class) distances
            t = torch.reshape(features_t, (int(batch_size / instances), instances, feature_size))
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(features, (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001]
            w_c = [x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Source Domain')
            plt.legend()
            plt.savefig("Source.png")
            plt.clf()
            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Target Domain')
            plt.legend()
            plt.savefig("Target.png")