Example #1
0
def train_epoch(
    clf: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_function: torch.nn.Module,
    words_train: List[List[str]],
    y_train: List[int],
    sequence_limit=32,
    batch_size=32,
    device="cpu",
) -> List[float]:
    clf.train()
    N = len(words_train)
    X, y = shuffle(words_train, y_train)
    epoch_pred = []
    losses = []
    with tqdm(range(0, N, batch_size)) as progress:
        for start in progress:
            clf.train()
            end = min(start + batch_size, N)
            X_batch = [x[:sequence_limit] for x in X[start:end]]
            y_batch = torch.tensor(y[start:end], dtype=torch.long).to(device)
            clf.zero_grad()
            y_scores = clf(X_batch)
            loss = loss_function(y_scores, y_batch)
            loss.backward()
            optimizer.step()

            clf.eval()
            # predicted label is 1 when the class-1 score exceeds the class-0 score
            # (epoch_pred is collected for inspection but not returned)
            epoch_pred.extend(((y_scores[:, 1] - y_scores[:, 0]) > 0).tolist())
            losses.append(loss.item())
            progress.set_description("Train Loss: {:.03}".format(
                np.mean(losses[-10:])))
    return losses
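# For the two-class case above, the score-difference test used to build epoch_pred
# is equivalent to an argmax over the two logits; a quick self-contained check
# (the sample tensor below is illustrative, not data from the original code):
import torch

y_scores = torch.tensor([[0.2, 1.3], [2.0, -0.5]])            # rows of [class-0, class-1] logits
pred_diff = ((y_scores[:, 1] - y_scores[:, 0]) > 0).tolist()  # rule used in train_epoch
pred_argmax = y_scores.argmax(dim=1).bool().tolist()          # equivalent argmax formulation
assert pred_diff == pred_argmax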
Example #2
0
def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
    if done:
        v_s_ = np.array([0. for i in range(len(s_))])
    else:
        _, v_s_ = lnet.forward(v_wrap(s_))
        v_s_ = v_s_.detach().numpy()

    buffer_v_target = []
    for r in br[::-1]:  # accumulate discounted returns, walking backwards through the rewards
        v_s_ = r + v_s_ * gamma
        buffer_v_target.append(v_s_)
    buffer_v_target.reverse()  # restore chronological order so the targets align with bs/ba
    buffer_v_target = np.array(buffer_v_target).flatten('F').reshape((-1, 1))
    ba = np.array(ba).flatten('F').reshape((-1, 1))
    loss = lnet.loss_func(v_wrap(np.vstack(bs)), v_wrap(np.vstack(ba)),
                          v_wrap(np.vstack(buffer_v_target)))

    opt.zero_grad()
    loss.backward()
    # push: copy each local gradient onto the corresponding global-network parameter
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
        gp._grad = lp.grad
    opt.step()

    # pull global parameters
    lnet.load_state_dict(gnet.state_dict())
    return loss
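# push_and_pull converts NumPy buffers to tensors through a v_wrap helper that is
# not shown here; a typical definition consistent with how it is used above
# (the project's actual helper may differ slightly):
import numpy as np
import torch

def v_wrap(np_array, dtype=np.float32):
    # cast to the requested dtype and wrap as a torch tensor
    if np_array.dtype != dtype:
        np_array = np_array.astype(dtype)
    return torch.from_numpy(np_array)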
def train():
    for epoch in range(epochs):
        ts = time.time()
        print(epoch)
        for iter, (X, tar, Y) in enumerate(train_loader):
            optimizer.zero_grad()

            # inputs = X.to(computing_device)
            inputs = X.cuda()
            labels = Y.cuda()
            # labels = Y.to(computing_device)

            print("Getting outputs")
            outputs = resnet_model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Early-stop after a handful of batches (debug/testing leftover);
            # remove this block to train on the full loader.
            if iter > 5:
                break
            if iter % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(
                    epoch, iter, loss.item()))

        print("Finish epoch {}, time elapsed {}".format(
            epoch,
            time.time() - ts))
        #torch.save(resnet_model, 'best_model')

        #val(epoch)
        resnet_model.train()
Example #4
0
    def training_module(self):
        if self.memory_counter < self.BATCH_SIZE:
            return

        if self.learn_step_counter % 100 == 0:
            self.TARGET_NETWORK.load_state_dict(self.LOCAL_NETWORK.state_dict())
        self.learn_step_counter += 1

        # sample BATCH_SIZE row indices from the replay buffer (BUFFER, BATCH_SIZE, NUM_STATES
        # and gamma are module-level constants here, while self.BATCH_SIZE is used above);
        # note that rows beyond self.memory_counter may still be empty until the buffer fills up
        index = np.random.choice(BUFFER, BATCH_SIZE)
        memory = self.memory[index, :]
        state = torch.FloatTensor(memory[:, :NUM_STATES])
        action = torch.LongTensor(memory[:, NUM_STATES:NUM_STATES + 1].astype(int))
        reward = torch.FloatTensor(memory[:, NUM_STATES + 1:NUM_STATES + 2])
        next_state = torch.FloatTensor(memory[:, -NUM_STATES:])

        Q_VALUE = self.LOCAL_NETWORK(state).gather(1, action)
        Q_NEXT = self.TARGET_NETWORK(next_state).detach()
        TARGET = reward + gamma * Q_NEXT.max(1)[0].view(BATCH_SIZE, 1)
        loss = self.loss_func(Q_VALUE, TARGET)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.eps_min:
            self.epsilon = self.epsilon - self.eps_dec
        else:
            self.epsilon = self.eps_min

        # print('epsilon value', self.epsilon)
        return loss
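# The slicing in training_module implies each replay-memory row is laid out as
# [ state (NUM_STATES) | action (1) | reward (1) | next_state (NUM_STATES) ].
# A minimal store_transition consistent with that layout; the method name and the
# constants below are illustrative assumptions, not part of the original class:
import numpy as np

NUM_STATES = 4   # illustrative state dimension
BUFFER = 2000    # illustrative replay-buffer capacity

def store_transition(self, state, action, reward, next_state):
    # one row: NUM_STATES state values, 1 action, 1 reward, NUM_STATES next-state values
    transition = np.hstack((state, [action, reward], next_state))
    index = self.memory_counter % BUFFER  # overwrite the oldest row once the buffer is full
    self.memory[index, :] = transition
    self.memory_counter += 1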
Example #5
0
def train(model, train_loader, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()
    end = time.time()

    for i, (images, target) in enumerate(train_loader):
        images = images.cuda()
        target = target.cuda()

        # compute output
        logits = model(images)
        loss = criterion(logits, target)

        # measure accuracy and record loss
        prec1 = accuracy(logits.data, target)
        n = images.size(0)
        losses.update(loss.data.item(), n)
        top1.update(prec1.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ((i + 1) % print_freq) == 0:
            batch_time.update(time.time() - end)
            end = time.time()
            print('Epoch {0}: [{1}/{2}]\t'
                  'Time {batch_time.val:.4f}\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i + 1, len(train_loader),
                      batch_time=batch_time, loss=losses, top1=top1),
                  flush=True)
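# AverageMeter and accuracy are project-local helpers not shown above; a typical
# AverageMeter consistent with how it is used (.val, .avg, .update(value, n)) looks
# like this -- the original project's implementation may differ:
class AverageMeter:
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count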
Example #6
0
    def loss(self):
        '''This is the closure needed for the optimizer'''
        if self.pred is None or self.target is None:
            raise RuntimeError('Must call update() first')
        self.optim.zero_grad()
        loss = self.model.objective(self.pred, self.target)
        loss.backward()
        return loss
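# The method above is the closure form that optimizers such as torch.optim.LBFGS expect:
# zero the gradients, compute the loss, backpropagate, and return the loss. A self-contained
# toy example of the same pattern (the least-squares problem below is illustrative, not the
# surrounding project's model):
import torch

x = torch.randn(64, 3)
true_w = torch.tensor([[1.0], [-2.0], [0.5]])
y = x @ true_w
w = torch.zeros(3, 1, requires_grad=True)
optim = torch.optim.LBFGS([w], lr=0.5)

def closure():
    optim.zero_grad()
    loss = torch.nn.functional.mse_loss(x @ w, y)
    loss.backward()
    return loss

for _ in range(20):
    optim.step(closure)  # LBFGS re-evaluates the closure internally on each step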
def train_val(loader=None,
              model=None,
              loss_function=None,
              optimizer=None,
              train_enable=None,
              device=None,
              model_classifier=None,
              model_id=None):
    sum_loss = 0.0
    sum_mse = 0.0
    sum_mae = 0.0
    sum_psnr = 0.0
    sum_ssim = 0.0

    # note: train_enable is passed in as the string 'True' / 'False'
    if train_enable == 'True':
        model = model.train()
    else:
        model = model.eval()  # eval mode disables Dropout and freezes BatchNorm statistics

    for img_NAC, img_AC, _ in loader:

        img_NAC = img_NAC.float()
        img_NAC = img_NAC.to(device)
        img_AC = img_AC.float()
        img_AC = img_AC.to(device)

        cam = get_grad_cam(model_classifier, img_NAC)
        pred = process_cam(cam, model_id, model, img_NAC, device)

        loss = loss_function(pred, img_AC)  # Loss is just MSE
        mse, mae, psnr, ssim = matrics(img_AC, pred)

        if train_enable == 'True':
            optimizer.zero_grad()
            loss.backward()  # back propagation
            optimizer.step()

        sum_loss += float(loss.item())
        sum_mse += float(mse.item())
        sum_mae += float(mae.item())
        sum_psnr += float(psnr)
        sum_ssim += float(ssim)

    epoch_loss = sum_loss / len(loader)
    epoch_mse = sum_mse / len(loader)
    epoch_mae = sum_mae / len(loader)
    epoch_psnr = sum_psnr / len(loader)
    epoch_ssim = sum_ssim / len(loader)

    return epoch_loss, epoch_mse, epoch_mae, epoch_psnr, epoch_ssim
def train(model, loader, f_loss, optimizer, device, log_manager=None):
    """
    Train a model for one epoch, iterating over the loader
    using the f_loss to compute the loss and the optimizer
    to update the parameters of the model.

    Arguments :

        model     -- A torch.nn.Module object
        loader    -- A torch.utils.data.DataLoader
        f_loss    -- The loss function, i.e. a loss Module
        optimizer -- A torch.optim.Optimizer object
        device    -- a torch.device class specifying the device
                     used for computation

    Returns :
        tuple (mean loss per sample, accuracy); accuracy is currently always 0
        because the prediction-counting code below is commented out.
    """

    # We enter train mode. This is useless for the linear model
    # but is important for layers such as dropout, batchnorm, ...
    model.train()

    N = 0
    tot_loss, correct = 0.0, 0.0
    for i, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Compute the forward pass through the network up to the loss
        outputs = model(inputs)

        loss = f_loss(outputs, targets)

        # if i == 0:
        #     # if final_test:
        #     print("sending image")
        #     log_manager.tensorboard_send_image(
        #         i, inputs[0], targets[0], outputs[0], txt= "trainning")
        # print("Loss: ", loss)
        N += inputs.shape[0]
        # accumulate the batch loss weighted by the batch size (reuse the loss computed above
        # instead of running the loss function a second time)
        tot_loss += inputs.shape[0] * loss.item()

        # print("Output: ", outputs)
        # predicted_targets = outputs
        # correct += (predicted_targets == targets).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return tot_loss/N, correct/N
Example #9
0
    def loss_fn(self):
        """This is the closure needed for the optimizer"""
        self.run_batch()
        self.state.optim.zero_grad()
        loss = self.state.model.objective(self.quant, self.target)
        inputs = (self.mel_enc_input, self.state.model.encoding_bn)
        mel_grad, bn_grad = torch.autograd.grad(loss,
                                                inputs,
                                                retain_graph=True)
        self.state.model.objective.metrics.update({
            'mel_grad_sd': mel_grad.std(),
            'bn_grad_sd': bn_grad.std()
        })
        # loss.backward(create_graph=True, retain_graph=True)
        loss.backward()
        return loss
Example #10
0
def train(args):
    transformer = T.Compose([
        T.ToTensor(),
        # MNIST statistics are mean 0.1307 and std 0.3081; the original call had them swapped
        T.Normalize((0.1307,), (0.3081,)),
    ])
    train_data=torchvision.datasets.MNIST(root=args.data_path,transform=transformer,download=True,train=True)
    train_loader=torch.utils.data.DataLoader(train_data,batch_size=args.batch_size,shuffle=True,drop_last=True,num_workers=4)

    model_arg=model_dict[args.model][1]
    model_arg["act"]=(act_dict[args.act])
    device=torch.device(args.device)
    net=model_dict[args.model][0](**model_arg).to(device)

    if args.optimizer=='adam':
        optimizer=torch.optim.Adam(net.parameters(),lr=args.lr,betas=(0.9,0.99))
    elif args.optimizer=='SGD':
        optimizer=torch.optim.SGD(net.parameters(),lr=args.lr,momentum=0.9)
    else:
        optimizer=None
    loss_func=loss_dict[args.loss_func]()
    
    writer=tensorboardX.SummaryWriter()

    current_acc=0
    for epoch in range(args.epoch):
        total_loss=0.
        total_acc=0.
        for i,(images,labels) in enumerate(train_loader):
            images,labels=images.to(device),labels.to(device)
            outputs=net(images)
            loss=loss_func(outputs,labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss+=loss.item()
            acc=torch.sum(outputs.argmax(-1)==labels).item()
            total_acc+=acc/args.batch_size

            writer.add_scalar('data/loss',loss,i+epoch*len(train_loader))
        print("epoch%3d: loss=%.4f ,acc:%.2f%% " %(epoch,total_loss/len(train_loader),total_acc*100/len(train_loader)))
        if(epoch%1==0):
            eval_acc=eval(args,net)
            writer.add_scalar('data/acc',eval_acc,epoch)
            if eval_acc > current_acc:
                current_acc = eval_acc  # remember the best accuracy so far (missing in the original, so every epoch overwrote the "best" checkpoint)
                torch.save(net.state_dict(), '%s/best_%s_model.pth' % (args.checkpoints_path, args.model))
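# train(args) relies on an argparse namespace plus module-level model_dict/act_dict/loss_dict
# registries defined elsewhere. A sketch of an argument parser that would provide the fields
# used above (the default values and registry keys are assumptions):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data_path", default="./data")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--model", default="cnn")                  # key into model_dict
parser.add_argument("--act", default="relu")                   # key into act_dict
parser.add_argument("--loss_func", default="cross_entropy")    # key into loss_dict
parser.add_argument("--device", default="cuda")
parser.add_argument("--optimizer", default="adam", choices=["adam", "SGD"])
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--checkpoints_path", default="./checkpoints")
args = parser.parse_args()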
Example #11
0
    def trainloop(self, n_epochs):
        for epoch in range(1, n_epochs + 1):
            self.evaluate(mask_data.data, mask_data.label)
            loss_train = 0.0
            for input, realout in self.dataloader:
                predictout = self.network(input)

                loss = self.loss_fn(predictout, realout)

                self.optim.zero_grad()

                loss.backward()
                self.optim.step()
                loss_train += loss.item()
            #if epoch == 1 or epoch % 100 == 0:
            print(
                f'{datetime.datetime.now()} epoch {epoch} training loss {loss_train/len(self.dataloader)}'
            )
Example #12
0
def train(
    model: Model,
    device: Device,
    loader: DataLoader,
    optimizer: Optimizer,
    loss_function: Criterion,
    epoch: int,
    log: Logger,
    writer: Optional[SummaryWriter] = None,
    scheduler: Optional[Scheduler] = None,
) -> Tuple[float, float]:
    """
    Training loop
    :param model: PyTorch model to test
    :param device: torch.device or str, where to perform computations
    :param loader: PyTorch DataLoader over test dataset
    :param optimizer: PyTorch Optimizer bounded with model
    :param loss_function: criterion
    :param epoch: epoch id
    :param log: Logger
    :param writer: tensorboard SummaryWriter
    :param scheduler: optional PyTorch Scheduler
    :return: tuple(train loss, train accuracy)
    """
    model.train()
    model.to(device)

    meter_loss = Meter("loss")
    meter_corr = Meter("acc")

    batch_size = len(loader.dataset) / len(loader)
    tqdm_loader = tqdm(loader, desc=f"train epoch {epoch:03d}")
    for batch_idx, batch_data in enumerate(tqdm_loader):
        data, target = batch_data.images.to(device), batch_data.labels.to(
            device)
        optimizer.zero_grad()

        output = model(data)

        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        pred = output.argmax(dim=1, keepdim=True)
        # Display training status
        meter_loss.add(loss.item())
        meter_corr.add(pred.eq(target.view_as(pred)).sum().item())
        tqdm_loader.set_postfix({
            "loss": meter_loss.avg,
            "acc": 100 * meter_corr.avg / batch_size,
            # guard against scheduler=None, which would otherwise raise an AttributeError here
            "lr": scheduler.get_lr() if scheduler is not None else optimizer.param_groups[0]["lr"],
        })

    # Log in file and tensorboard
    acc = 100.0 * meter_corr.sum / len(loader.dataset)
    log.info("Train Epoch: {} [ ({:.0f}%)]\tLoss: {:.6f}".format(
        epoch, acc, meter_loss.avg))
    if writer is not None:
        writer.add_scalar("train_loss", loss.item(), global_step=epoch)
        writer.add_scalar("train_acc", acc, global_step=epoch)

    return meter_loss.avg, acc
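# Meter is another project-local helper; a minimal version consistent with how it is used
# above (constructor takes a name, .add() accumulates, .avg and .sum are read back) -- the
# project's real implementation may differ:
class Meter:
    """Running sum / average tracker."""

    def __init__(self, name):
        self.name = name
        self.sum = 0.0
        self.count = 0

    def add(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)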
Example #13
0
    def train(self,
              epoch,
              max_epoch,
              writer,
              print_freq=10,
              fixbase_epoch=0,
              open_layers=None):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        loss_meter = AverageMeter()
        criterion_pcb = torch.nn.CrossEntropyLoss()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()

        layer_nums = 3
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)

            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            self.optimizer.zero_grad()
            outputs, features, h, b, logits_list, b_classify = self.model(imgs)
            #print(len(logits_list))
            #print(logits_list[0].shape)
            pids_g = self.parse_pids(pids)
            x = features

            target_b = F.cosine_similarity(b[:pids_g.size(0) // 2],
                                           b[pids_g.size(0) // 2:])
            target_x = F.cosine_similarity(x[:pids_g.size(0) // 2],
                                           x[pids_g.size(0) // 2:])

            loss1 = F.mse_loss(target_b, target_x)
            loss2 = torch.mean(
                torch.abs(
                    torch.pow(
                        torch.abs(h) - Variable(torch.ones(h.size()).cuda()),
                        3)))
            loss_greedy = loss1 + 0.1 * loss2
            loss_batchhard_hash = self.compute_hashbatchhard(b, pids)

            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs,
                                        pids) + self._compute_loss(
                                            self.criterion_x, b_classify, pids)

            # sum the per-branch PCB cross-entropy losses over all logits heads
            id_loss = 0
            for logits in logits_list:
                loss_tmp = criterion_pcb(logits, pids)
                id_loss += loss_tmp
            if layer_nums >= 1:
                loss1 = criterion_pcb(logits_list[0], pids)
                loss2 = criterion_pcb(logits_list[1], pids)
                loss3 = criterion_pcb(logits_list[2], pids)
                loss4 = criterion_pcb(logits_list[3], pids)
                loss5 = criterion_pcb(logits_list[4], pids)
                loss6 = criterion_pcb(logits_list[5], pids)
                if layer_nums >= 2:
                    loss12 = criterion_pcb(logits_list[6], pids)
                    loss23 = criterion_pcb(logits_list[7], pids)
                    loss34 = criterion_pcb(logits_list[8], pids)
                    loss45 = criterion_pcb(logits_list[9], pids)
                    loss56 = criterion_pcb(logits_list[10], pids)
                    if layer_nums >= 3:
                        loss123 = criterion_pcb(logits_list[11], pids)
                        loss234 = criterion_pcb(logits_list[12], pids)
                        loss345 = criterion_pcb(logits_list[13], pids)
                        loss456 = criterion_pcb(logits_list[14], pids)
            zero = Variable(torch.zeros(1).cuda())
            # hinge-style constraints: each merged branch's loss should not exceed
            # the losses of the individual branches it merges
            metric_loss = (
                torch.max(loss12 - loss1.detach(), zero) + torch.max(loss12 - loss2.detach(), zero)
                + torch.max(loss23 - loss2.detach(), zero) + torch.max(loss23 - loss3.detach(), zero)
                + torch.max(loss34 - loss3.detach(), zero) + torch.max(loss34 - loss4.detach(), zero)
                + torch.max(loss45 - loss4.detach(), zero) + torch.max(loss45 - loss5.detach(), zero)
                + torch.max(loss56 - loss5.detach(), zero) + torch.max(loss56 - loss6.detach(), zero)
            )

            loss_pcb_record = id_loss + metric_loss
            loss = id_loss + metric_loss + self.weight_t * loss_t + self.weight_x * loss_x + loss_greedy + loss_batchhard_hash * 2

            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))

            accs.update(metrics.accuracy(outputs, pids)[0].item())
            loss_meter.update(to_scalar(loss_pcb_record))
            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                      'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                      'Loss_g {loss_g:.4f}\t'
                      'Loss_p {loss_p:.4f}\t'
                      'Loss_pcb {loss_pcb:.4f}\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.format(
                          epoch + 1,
                          max_epoch,
                          batch_idx + 1,
                          num_batches,
                          batch_time=batch_time,
                          data_time=data_time,
                          loss_t=losses_t,
                          loss_x=losses_x,
                          loss_g=loss_greedy,
                          loss_p=loss_batchhard_hash,
                          loss_pcb=loss_meter.val,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str))

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar('Train/Lr',
                                  self.optimizer.param_groups[0]['lr'], n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        losses_recons = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        self.mgn_targetPredict.train()

        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)
            open_all_layers(self.mgn_targetPredict)
            print("All open layers!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        num_batches = len(self.train_loader)
        end = time.time()
       
# -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)
            

            imgs, pids = self._parse_data_for_train(data)
            imgs_clean =  imgs.clone().cuda()
            lam=0
            imgs_t, pids_t = self._parse_data_for_train(data_t)
            imagest_orig=imgs_t.cuda()
            labels = []
            labelss = []
            random_indexS = np.random.randint(0, imgs.size()[0])
            random_indexT = np.random.randint(0, imgs_t.size()[0])

            # progressively stronger RandomErasing on the source images as training advances;
            # RandomErasing returns the (possibly erased) image plus a flag p, collected as an occlusion label
            if epoch > 10 and epoch < 35:
                for i, img in enumerate(imgs):
                    randmt = RandomErasing(probability=0.5, sl=0.07, sh=0.22)
                    imgs[i], p = randmt(img, imgs[random_indexS])
                    labelss.append(p)

            if epoch >= 35:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.25)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img, imgs[random_indexS])
                    labelss.append(p)

            # the same progressive schedule for the target-domain images
            if epoch > 10 and epoch < 35:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.2)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            if epoch >= 35 and epoch < 75:
                randmt = RandomErasing(probability=0.5, sl=0.2, sh=0.3)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            if epoch >= 75:
                randmt = RandomErasing(probability=0.5, sl=0.2, sh=0.35)
                for i, img in enumerate(imgs_t):
                    imgs_t[i], p = randmt(img, imgs_t[random_indexT])
                    labels.append(p)

            # binary occlusion labels for the target (labels) and source (labelss) batches
            binary_labels = torch.tensor(np.asarray(labels)).cuda()
            binary_labelss = torch.tensor(np.asarray(labelss)).cuda()

            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            if self.use_gpu:
                imgs_transformed = imgs_t.cuda()

            self.optimizer.zero_grad()

            imgs_clean = imgs
            outputs, output2, recons, bcc1, bocc2, bocc3 = self.model(imgs)

            # occlusion-classification losses on the (possibly erased) source images
            occ_losss1 = self.BCE_criterion(bcc1.squeeze(1), binary_labelss.float())
            occ_losss2 = self.BCE_criterion(bocc2.squeeze(1), binary_labelss.float())
            occ_losss3 = self.BCE_criterion(bocc3.squeeze(1), binary_labelss.float())

            occ_s = occ_losss1 + occ_losss2 + occ_losss3

            ############## CUT MIX (disabled) ##############
            """bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam)
            rand_index = torch.randperm(imgs.size()[0]).cuda()
            imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
            targeta = pids
            targetb = pids[rand_index]"""

            ############## CUT MIX ##############

            # forward pass on the unaugmented target-domain images
            outputs_t, output2_t, recons_t, bocct1, bocct2, bocct3 = self.model(imagest_orig)
            outputs_t = self.mgn_targetPredict(output2_t)

            # reconstruction losses for target and source images
            loss_reconst = self.criterion_mse(recons_t, imagest_orig)
            loss_recons = self.criterion_mse(recons, imgs_clean)

            # occlusion-classification losses on the (possibly erased) target images
            occ_loss1 = self.BCE_criterion(bocct1.squeeze(1), binary_labels.float())
            occ_loss2 = self.BCE_criterion(bocct2.squeeze(1), binary_labels.float())
            occ_loss3 = self.BCE_criterion(bocct3.squeeze(1), binary_labels.float())
            occ_t = occ_loss1 + occ_loss2 + occ_loss3
            pids_t = pids_t.cuda()
            loss_x = self.mgn_loss(outputs, pids)
            loss_x_t = self.mgn_loss(outputs_t, pids_t)
            #loss_x_t = self._compute_loss(self.criterion_x, y, targeta)  #*lam + self._compute_loss(self.criterion_x, y, targetb)*(1-lam)
            #loss_t_t = self._compute_loss(self.criterion_t, features_t, targeta)*lam + self._compute_loss(self.criterion_t, features_t, targetb)*(1-lam)
                      
         
            if epoch > 10:
                # domain-alignment (MMD) losses are only enabled after a warm-up period
                loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(
                    self.criterion_mmd, outputs[0], outputs_t[0])
                #loss_mmd_wc1, loss_mmd_bc1, loss_mmd_global1 = self._compute_loss(self.criterion_mmd, outputs[2], outputs_t[2])
                #loss_mmd_wc3, loss_mmd_bc3, loss_mmd_global3 = self._compute_loss(self.criterion_mmd, outputs[3], outputs_t[3])

                l_joint = 1.5 * loss_x_t + loss_x + loss_reconst + loss_recons
                l_d = 0.5 * loss_mmd_bc + 0.8 * loss_mmd_wc + loss_mmd_global
                loss = 0.3 * l_d + 0.7 * l_joint + 0.2 * occ_t + 0.1 * occ_s
            else:
                # The original never defined `loss` for the warm-up epochs, so the backward()
                # call below raised a NameError; as a fix we assume the same objective without
                # the MMD terms until the warm-up ends.
                l_joint = 1.5 * loss_x_t + loss_x + loss_reconst + loss_recons
                loss = 0.7 * l_joint + 0.2 * occ_t + 0.1 * occ_s

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
# -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            #losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x_t.item(), pids.size(0))
            #losses_recons.update(loss_recons.item(), pids.size(0))
            if epoch > 10:
                losses_mmd_bc.update(loss_mmd_bc.item(), pids.size(0))
                losses_mmd_wc.update(loss_mmd_wc.item(), pids.size(0))
                losses_mmd_global.update(loss_mmd_global.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    #'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                    'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                    'Loss_mmd_wc {losses3.val:.4f} ({losses3.avg:.4f})\t'
                    'Loss_mmd_bc {losses4.val:.4f} ({losses4.avg:.4f})\t'
                    'Loss_mmd_global {losses5.val:.4f} ({losses5.avg:.4f})\t'
                    #'Loss_recons {losses6.val:.4f} ({losses6.avg:.4f})\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        #losses1=losses_triplet,
                        losses2=losses_softmax,
                        losses3=losses_mmd_wc,
                        losses4=losses_mmd_bc,
                        losses5=losses_mmd_global,
                        #losses6 = losses_recons,
                        eta=eta_str
                    )
                )
            writer = None  # NOTE: debug leftover that disables the TensorBoard logging below
            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg, n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_bc', losses_mmd_bc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_wc', losses_mmd_wc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_global', losses_mmd_global.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
        print_distri = True

        if print_distri:

            instances = self.datamanager.test_loader.query_loader.num_instances
            batch_size = self.datamanager.test_loader.batch_size
            feature_size = outputs[0].size(1) # features_t.shape[1]  # 2048
            features_t = outputs_t[0]
            features = outputs[0]
            t = torch.reshape(features_t, (int(batch_size / instances), instances, feature_size))
 
            #  and compute bc/wc euclidean distance
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(features, (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001]
            w_c = [x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Source Domain')
            plt.legend()
            plt.savefig("Source.png")
            plt.clf()
            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Target Domain')
            plt.legend()
            plt.savefig("Target.png")
Example #15
0
    def train(
        self,
        epoch,
        max_epoch,
        writer,
        print_freq=10,
        fixbase_epoch=0,
        open_layers=None
    ):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        accs_b = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        center_loss=CenterLoss(num_classes=751, feat_dim=4608)
        #center_loss_h=CenterLoss(num_classes=751, feat_dim=256)
        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        
        
        layer_nums=3
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)

            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            self.optimizer.zero_grad()
            outputs, features, h, b, cls_score, b_classify = self.model(imgs)
            #print(len(logits_list))
            #print(logits_list[0].shape)
            pids_g = self.parse_pids(pids, self.datamanager.num_train_pids)
            x = features
            pids_ap = pids_g.cuda()
            #AP_loss=aploss_criterion(b_classify,pids_ap)

            target_b = F.cosine_similarity(b[:pids_g.size(0) // 2], b[pids_g.size(0) // 2:])
            target_x = F.cosine_similarity(x[:pids_g.size(0) // 2], x[pids_g.size(0) // 2:])

            loss1 = F.mse_loss(target_b, target_x)
            loss2 = torch.mean(torch.abs(torch.pow(torch.abs(h) - Variable(torch.ones(h.size()).cuda()), 3)))
            loss_greedy = loss1 + 0.1 * loss2
            loss_batchhard_hash = self.compute_hashbatchhard(b, pids)

            #print(features.shape)
            loss_t = self._compute_loss(self.criterion_t, features, pids)  # +self._compute_loss(self.criterion_t, b, pids)
            loss_x = (self._compute_loss(self.criterion_x, outputs, pids)
                      + self._compute_loss(self.criterion_x, b_classify, pids)
                      + self._compute_loss(self.criterion_x, cls_score, pids))

            centerloss = 0  # center_loss(features, pids) + center_loss_h(h, pids)
            centerloss = centerloss * 0.0005

            #print(centerloss)
            loss = (centerloss + self.weight_t * loss_t + self.weight_x * loss_x
                    + loss_greedy + loss_batchhard_hash * 2)  # +AP_loss
            # loss = centerloss + self.weight_x * loss_x + loss_greedy + loss_batchhard_hash*2  # +AP_loss

            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            #losses_t.update(loss_t.item(), pids.size(0))
            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))

            accs.update(metrics.accuracy(outputs, pids)[0].item())
            accs_b.update(metrics.accuracy(b_classify, pids)[0].item())
            if (batch_idx+1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                    num_batches - (batch_idx+1) + (max_epoch -
                                                   (epoch+1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                    'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                    'Loss_g {loss_g:.4f}\t'
                    'Loss_p {loss_p:.4f}\t'
                    'Loss_cl {loss_cl:.4f}\t'
                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                    'Acc_b {acc_b.val:.2f} ({acc_b.avg:.2f})\t'
                    'Lr {lr:.6f}\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss_t=losses_t,
                        loss_x=losses_x,
                        loss_g=loss_greedy,
                        loss_p=loss_batchhard_hash,
                        loss_cl=centerloss,
                        #loss_ap=AP_loss,
                        acc=accs,
                        acc_b=accs_b,
                        lr=self.optimizer.param_groups[0]['lr'],
                        eta=eta_str
                    )
                )

            if writer is not None:
                n_iter = epoch*num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
Example #16
0
    def train(self):
        self.model.train()

        # log data to these variables
        # (re)initialise the logging lists unless we resumed from a loaded checkpoint
        if not self.settings.get('loaded', False):
            self.model.training_loss = []
            self.model.training_acc = []
            self.model.validation_acc = []
            self.model.validation_loss = []

        for epoch in range(self.settings['EPOCHS']):
            self.model.train()
            ts = time.time()
            lossSum = 0
            accuracySum = 0
            totalImage = 0
            for iter, (X, tar, Y) in enumerate(self.train_loader):

                
                self.optimizer.zero_grad()

                if('imagesPerEpoch' in self.settings):
                    if iter*self.batch_size > self.settings['imagesPerEpoch']:
                        break
                

                #inputs = X.to(computing_device)
                inputs = X.cuda()
                labels = Y.cuda()
                #labels = Y.to(computing_device)

                outputs = self.model(inputs)

                loss = self.criterion(outputs, labels)

                lossSum += loss.item()

                accuracies = pixel_acc(outputs, labels)

                accuracySum += torch.sum(accuracies)/self.batch_size

                torch.cuda.empty_cache()
                
                loss.backward()
                self.optimizer.step()

                

                totalImage += 1

                if iter % 100 == 0:
                    print("Iter", iter, "Done")
                    #print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.item()))
            lossSum = lossSum / totalImage

            self.model.training_loss.append(lossSum)
            accuracy = accuracySum / totalImage
            if accuracy is None:  # vestigial guard; accuracySum is always a tensor here
                accuracy = torch.tensor([0.0])
            self.model.training_acc.append(accuracy.item())
            print(totalImage*self.batch_size)
            print("-------------------------------------")
            print("Train epoch {}, time elapsed {}, loss {}, accuracy: {}".format(epoch, time.time() - ts, lossSum, accuracy.item()))


            self.val(epoch)
Example #17
0
    def compute_acer_loss(self,
                          policies,
                          q_values,
                          values,
                          actions,
                          rewards,
                          retrace,
                          masks,
                          behavior_policies,
                          entropy_weight=None,
                          gamma=None,
                          inner_update=False,
                          truncation_clip=10,
                          loss_list=None):
        # the entropy_weight/gamma arguments are ignored; the agent's own settings are used instead
        entropy_weight = self.entropy_weight
        gamma = self.gamma
        loss = 0

        for step in reversed(range(len(rewards))):
            importance_weight = policies[step].detach(
            ) / behavior_policies[step].detach()

            # NOTE: with `or` this assertion is vacuously true (a ratio cannot equal both
            # 1.0 and 0.0 at once); `and` was probably intended.
            assert sum(masks[step]) / len(masks[step]) != 1.0 or sum(
                masks[step]) / len(masks[step]) != 0.0
            retrace = rewards[step].view(
                -1, 1) + gamma * retrace * sum(masks[step]) / len(masks[step])
            advantage = retrace - values[step]

            log_policy_action = policies[step].gather(
                1, actions[step].view(-1, 1)).log()
            assert log_policy_action.shape == actions[step].view(-1,1).shape, \
                f"log_policy_action.shape : {log_policy_action.shape},  actions[step].view(-1,1).shape : {actions[step].view(-1,1).shape}"

            truncated_importance_weight = importance_weight.gather(
                1, actions[step].view(-1, 1)).clamp(max=truncation_clip)
            assert truncated_importance_weight.shape == actions[step].view(-1,1).shape, \
                f"truncated_importance_weight.shape : {truncated_importance_weight.shape}, actions[step].view(-1,1).shape : {actions[step].view(-1,1).shape}"

            actor_loss = -(truncated_importance_weight * log_policy_action *
                           advantage.detach()).mean(0)
            correction_weight = (1 -
                                 truncation_clip / importance_weight).clamp(
                                     min=0)
            actor_loss -= (correction_weight * policies[step].log() *
                           policies[step]).sum(1).mean(0)

            entropy = entropy_weight * -(policies[step].log() *
                                         policies[step]).sum(1).mean(0)

            q_value = q_values[step].gather(1, actions[step].view(-1, 1))
            #critic_loss = ((retrace - q_value) ** 2 / 2).mean()
            critic_loss = nn.MSELoss()(retrace, q_value)

            truncated_rho = importance_weight.gather(1, actions[step].view(
                -1, 1)).clamp(max=1)
            assert truncated_rho.shape == actions[step].view(-1,1).shape, \
                f'truncated_rho.shape : {truncated_rho.shape}, actions[step].view(-1,1).shape : {actions[step].view(-1,1).shape}'

            retrace = truncated_rho * (
                retrace - q_value.detach()) + values[step].detach()

            loss += actor_loss + critic_loss - entropy

        if inner_update:
            return loss
        else:
            self.feature_encoder_optim.zero_grad()
            self.model_optim.zero_grad()

            loss.backward()

            self.feature_encoder_optim.step()
            self.model_optim.step()
Example #18
0
def fit(model, train_dataset, device, epoch=0, image_index=0, optimizer=None):
    if (optimizer == None):
        print('instantiating optimizer')
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        print('optimizer instantiated')

    criterion = nn.CrossEntropyLoss()  # Error function: cross entropy
    # Stochastic gradient descent
    # Initial learning rate: 0.001
    # momentum: 0.9

    running_loss = 1.0
    images_since_last_save = 0

    # Runs training for two epochs (iterates over the 87 thousand images twice);
    # note the `or` keeps training beyond two epochs while the loss stays below 1e-2,
    # which looks inverted (a `... and running_loss > 10e-3` convergence check was probably intended)
    while epoch < 2 or running_loss < 10e-3:
        running_loss = 0.0
        print('epoch', epoch)
        model.train(
        )  # Sets a flag indicating the code that follows performs training
        # this makes sure training functionality like dropout and batch normalization perform
        #as expected

        print('loading new batch')
        batch_start = timer()

        for index, (samples, labels) in enumerate(train_dataset):
            batch_end = timer()
            print('batch loaded. time elapsed: ', batch_end - batch_start)
            if (image_index % 1000 == 0):
                print('current image:', image_index)
            # the variable data contains an entire batch of inputs and their associated labels

            #samples, labels = data
            #print('sending data to device')
            device_start = timer()
            samples, labels = samples.to(device), labels.to(
                device)  # Sends the data to the GPU
            device_end = timer()
            #print('data sent. elapsed time', device_end-device_start)

            #print("zeroing grad")
            optimizer.zero_grad(
            )  # Zeroes the gradient, otherwise it will accumulate at every iteration
            # the result would be that the network would start taking huge parameter jumps as training went on
            #print('grad zeroed')

            #print('inferring...')
            infer_start = timer()
            output = model(samples)[:, :, :800, :
                                    800]  # Forward passes the input data
            infer_end = timer()
            #print('inferred')
            #print('time elapsed during inference:', infer_end - infer_start)

            #print('computing loss')
            loss_start = timer()
            loss = criterion(output, labels)  # Computes the error
            loss.backward(
            )  # Computes the gradient, yielding how much each parameter must be updated
            loss_end = timer()

            #print('updating weights')
            weights_start = timer()
            optimizer.step(
            )  # Updates each parameter according to the gradient
            weights_end = timer()
            #print('weights updated. time elapsed: ', weights_end-weights_start)

            running_loss = loss.item()
            print('running loss', running_loss)
            '''if index % 10 == 9:
                print('[%d %5d] loss %.3f' % (epoch + 1, index + 1, running_loss / 2000))
                running_loss = 0.0'''
            #print('loading new batch')
            batch_start = timer()

            image_index += samples.size()[0]

            images_since_last_save += samples.size()[0]
            if (images_since_last_save > 500):
                print('saving checkpoint at image', image_index)
                save_model(
                    model, epoch, image_index, optimizer, 'customfcn_' +
                    str(epoch) + '_' + str(image_index) + '.pickle')
                model = model.to(device)
                images_since_last_save = 0

        image_index = 0
        epoch += 1  # advance the epoch counter (missing in the original, which made the loop run forever)
    print('finished training')
Example #19
0
    def _train(
        self,
        train_data,
        epoch,
        val_data=None,
        val_step=None,
        ckpt_step=None,
    ):
        """helper method, called by the fit method on each epoch.
        Iterates once through train_data, using it to update model parameters.
        Override this method if you need to implement your own training method.

        Parameters
        ----------
        train_data : torch.utils.data.DataLoader
            instance that will be iterated over.
        epoch : int
            current epoch number, used for logging and checkpointing.
        val_data : torch.utils.data.DataLoader, optional
            validation set; metrics are computed on it every ``val_step`` global steps.
        val_step, ckpt_step : int, optional
            global-step intervals for validation and checkpointing.
        """
        self.network.train()

        progress_bar = tqdm(train_data)
        for ind, batch in enumerate(progress_bar):
            x, y = batch[0].to(self.device), batch[1].to(self.device)
            y_pred = self.network.forward(x)
            self.optimizer.zero_grad()
            loss = self.loss(y_pred, y)
            loss.backward()
            self.optimizer.step()
            progress_bar.set_description(
                f'Epoch {epoch}, batch {ind}. Loss: {loss.item():.4f}. Global step: {self.global_step}'
            )

            if self.summary_writer is not None:
                self.summary_writer.add_scalar('loss/train', loss.item(),
                                               self.global_step)
            self.global_step += 1

            if val_data is not None:
                if val_step is not None and self.global_step % val_step == 0:
                    log_or_print(
                        f'Step {self.global_step} is a validation step; computing metrics on validation set',
                        logger=self.logger,
                        level='info')
                    metric_vals = self._eval(val_data)
                    self.network.train()  # because _eval calls network.eval()
                    log_or_print(msg=', '.join([
                        f'{metric_name}: {metric_value:.4f}'
                        for metric_name, metric_value in metric_vals.items()
                        if metric_name.startswith('avg_')
                    ]),
                                 logger=self.logger,
                                 level='info')

                    if self.summary_writer is not None:
                        for metric_name, metric_value in metric_vals.items():
                            if metric_name.startswith('avg_'):
                                self.summary_writer.add_scalar(
                                    f'{metric_name}/val', metric_value,
                                    self.global_step)

                    current_val_acc = metric_vals['avg_acc']
                    if current_val_acc > self.max_val_acc:
                        self.max_val_acc = current_val_acc
                        log_or_print(
                            msg=
                            f'Accuracy on validation set improved. Saving max-val-acc checkpoint.',
                            logger=self.logger,
                            level='info')
                        self.save(self.max_val_acc_ckpt_path,
                                  epoch=epoch,
                                  global_step=self.global_step)
                        if self.patience:
                            self.patience_counter = 0
                    else:  # if accuracy did not improve
                        if self.patience:
                            self.patience_counter += 1
                            if self.patience_counter > self.patience:
                                log_or_print(
                                    'Stopping training early, '
                                    f'accuracy has not improved in {self.patience} validation steps.',
                                    logger=self.logger,
                                    level='info')
                                # save "backup" checkpoint upon stopping; don't save over "max-val-acc" checkpoint
                                self.save(self.ckpt_path,
                                          epoch=epoch,
                                          global_step=self.global_step)
                                progress_bar.close()
                                break
                            else:
                                log_or_print(
                                    f'Accuracy has not improved in {self.patience_counter} validation steps. '
                                    f'Not saving max-val-acc checkpoint for this validation step.',
                                    logger=self.logger,
                                    level='info')
                        else:  # patience is None. We still log that we are not saving checkpoint.
                            log_or_print(
                                'Accuracy is less than maximum validation accuracy so far. '
                                'Not saving max-val-acc checkpoint.',
                                logger=self.logger,
                                level='info')

            # below can be true regardless of whether we have val_data and/or current epoch is a val_epoch
            if ckpt_step is not None and self.global_step % ckpt_step == 0:
                log_or_print(f'Step {self.global_step} is a checkpoint step.',
                             logger=self.logger,
                             level='info')
                self.save(self.ckpt_path,
                          epoch=epoch,
                          global_step=self.global_step)
Example #20
0
    def train(self,
              epoch,
              max_epoch,
              writer,
              print_freq=10,
              fixbase_epoch=0,
              open_layers=None):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        losses_t_pcb = AverageMeter()
        losses_x_pcb = AverageMeter()
        accs_pcb = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()

        # self.hamming(torch.tensor([0,1,1,0,1]),torch.tensor([0,1,0,0,0]))

        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)

            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            self.optimizer.zero_grad()
            # outputs, features = self.model(imgs)
            outputs, features, h, b = self.model(imgs)
            # print(b_weighted)
            pids_g = self.parse_pids(pids)
            x = features

            # differ=h-rebuild
            # loss_con=torch.sum(torch.abs(differ))/(256*32)
            # print(loss_con)

            target_b = F.cosine_similarity(b[:pids_g.size(0) // 2],
                                           b[pids_g.size(0) // 2:])
            target_x = F.cosine_similarity(x[:pids_g.size(0) // 2],
                                           x[pids_g.size(0) // 2:])

            # print(loss_circle(h,pids))
            #loss_quantization=0#self.logcosh(h.abs() - 1).mean()
            #loss_average=0#(torch.abs(torch.sum(h)))/32
            loss1 = F.mse_loss(target_b, target_x)
            loss2 = torch.mean(
                torch.abs(
                    torch.pow(
                        torch.abs(h) - Variable(torch.ones(h.size()).cuda()),
                        3)))
            loss_greedy = loss1 + 0.1 * loss2  #+loss_average*0.5
            # print(loss_average)
            loss_batchhard_hash = self.compute_hashbatchhard(b, pids)

            #Loss_AMS=AMSoftmax(h.shape[1])
            #loss_ams=Loss_AMS(h,pids)
            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            # print("mgn:",mgn_part.shape)
            # test=mgn_part[:][:8]
            # print("test:",test.shape)
            # print(features.shape)

            # loss_mgn=Loss()
            # Loss_mgn=loss_mgn(mgn_part,pids)
            # loss_t_mgn = self._compute_loss(self.criterion_t, mgn_part[:], pids)
            # loss_x_mgn = self._compute_loss(self.criterion_x, mgn_part[4:], pids)
            loss = (self.weight_t * loss_t + self.weight_x * loss_x
                    + loss_greedy + loss_batchhard_hash * 2)  # +loss_ams*0.5 +loss_x_mgn+loss_t_mgn +loss_ksh*0.1 +loss_class*0.01
            # loss_x_yh = self._compute_loss(self.criterion_x, y_h, pids)
            # loss=loss+loss_x_yh*0.5

            # feat=nn.functional.normalize(features)
            # inp_sp, inp_sn = convert_label_to_similarity(features, pids)
            # criterion = CircleLoss(m=0.25, gamma=80)
            #circle_loss = 0#criterion(inp_sp, inp_sn)/(32*2048)
            #loss=loss+circle_loss

            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss_x.item(), pids.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                    'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                    'Loss_g {loss_g:.4f}\t'
                    'Loss_p {loss_p:.4f}\t'
                    #'loss_ams {loss_ams:.4f} )\t'
                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                    'Lr {lr:.6f}\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss_t=losses_t,
                        loss_x=losses_x,
                        loss_g=loss_greedy,
                        loss_p=loss_batchhard_hash,
                        acc=accs,
                        #loss_ams=loss_ams,
                        # acc_pcb=accs_pcb,
                        # loss_t_pcb=losses_t_pcb,
                        # loss_x_pcb=losses_x_pcb,
                        lr=self.optimizer.param_groups[0]['lr'],
                        eta=eta_str))

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar('Train/Lr',
                                  self.optimizer.param_groups[0]['lr'], n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
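
The two hashing terms above (a similarity-alignment loss between code-space and feature-space cosine similarities, and a cubic quantization penalty pushing |h| toward 1) can be isolated as follows. This is a minimal sketch assuming the batch is arranged as matched pairs split across its two halves; the function name and shapes are illustrative, not part of the original code.

import torch
import torch.nn.functional as F


def greedy_hash_losses(h, b, x):
    """Sketch of the hashing regularizers used in the loop above.

    h: pre-binarization codes, shape (B, K)
    b: relaxed binary codes,   shape (B, K)
    x: real-valued features,   shape (B, D)
    Assumes the first and second halves of the batch form matched pairs.
    """
    half = b.size(0) // 2
    # Align the pairwise cosine similarity of codes with that of features.
    target_b = F.cosine_similarity(b[:half], b[half:])
    target_x = F.cosine_similarity(x[:half], x[half:])
    loss_align = F.mse_loss(target_b, target_x)
    # Cubic penalty that pushes each |h| entry toward 1 (quantization loss).
    loss_quant = torch.mean(torch.abs(torch.pow(torch.abs(h) - 1.0, 3)))
    return loss_align + 0.1 * loss_quant
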
Example #21
def do_contrastive_train(
    cfg,
    model,
    data_loader,
    criterion,
    optimizer,
    epoch,
    device=None,
    meters=None,
    logger=None,
):
    r"""Contrastive training implementation:

    Args:
        cfg: (CfgNode) Specified configuration. The final config settings
    integrates the initial default setting and user defined settings given
    by argparse.
        model: (nn.Module) Under the context of this repository, the model
    can be a simple PyTorch neural network or the instance wrapped by
    ContrastiveWrapper.
        data_loader:
        criterion:
        optimizer:
        device:
        epoch:
        meters:

    Returns:

    """
    # Capture display logger
    # logger = logging.getLogger("kknight")
    logger.info("Epoch {epoch} now started.".format(epoch=epoch))

    # Switch to train mode
    model.train()

    # Timers
    end = time.time()
    data_time, batch_time = 0, 0

    # Gradient accumulation interval and statistic display interval
    n_accum_grad = cfg.SOLVER.ACCUM_GRAD
    n_print_intv = n_accum_grad * cfg.SOLVER.DISP_INTERVAL
    max_iter = len(data_loader)

    for iteration, ((xis, xjs), _) in enumerate(data_loader):
        data_time += time.time() - end

        if device is not None:
            xis = xis.cuda(device, non_blocking=True)
            xjs = xjs.cuda(device, non_blocking=True)

        # Compute embedding and target label
        # output, target, extra = model(xis, xjs)
        output, target = model(xis, xjs)
        loss = criterion(output, target)

        # acc1/acc5 are (k + 1)-way contrastive classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = contrastive_accuracy(output, target, topk=(1, 5))
        # meters.update(loss=loss, **extra)
        meters.update(loss=loss)
        meters.update(acc1=acc1, acc5=acc5)
        loss.backward()

        # Compute batch time
        batch_time += time.time() - end
        end = time.time()

        if (iteration + 1) % n_accum_grad == 0 or iteration + 1 == max_iter:
            optimizer.step()
            # scheduler.step()
            optimizer.zero_grad()

            # Record batch time and data sampling time
            meters.update(time=batch_time, data=data_time)
            data_time, batch_time = 0, 0

        if (iteration + 1) % n_print_intv == 0 or iteration + 1 == max_iter:
            # Estimated time of arrival of remaining epoch
            total_eta = meters.time.global_avg * max_iter * (cfg.SOLVER.EPOCH -
                                                             epoch)
            eta_seconds = meters.time.global_avg * (max_iter -
                                                    iteration) + total_eta
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            logger.info(
                meters.delimiter.join([
                    "eta: {eta}", "epoch: {epoch}", "iter: {iter}", "{meters}",
                    "lr: {lr:.6f}", "max mem: {memory:.0f}"
                ]).format(
                    eta=eta_string,
                    epoch=epoch,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024. / 1024.,
                ))
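
Note that the loop above calls `loss.backward()` every iteration but only steps the optimizer every `cfg.SOLVER.ACCUM_GRAD` iterations. A minimal stand-alone sketch of that gradient-accumulation pattern (the function and argument names here are illustrative, and `batches` is assumed to be a DataLoader or list):

def accumulate_and_step(model, optimizer, criterion, batches, n_accum_grad=4):
    """Accumulate gradients over n_accum_grad mini-batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        loss = criterion(model(inputs), targets)
        loss.backward()                      # gradients accumulate in .grad
        if (i + 1) % n_accum_grad == 0 or (i + 1) == len(batches):
            optimizer.step()
            optimizer.zero_grad()

Some implementations additionally divide the loss by `n_accum_grad` so the accumulated gradient matches the mean over the window; the loop above, like the original, keeps the unscaled sum.
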
Example #22
def do_lincls_train(
    cfg,
    model,
    data_loader,
    criterion,
    optimizer,
    epoch,
    device=None,
    meters=None,
    logger=None,
):
    # Capture display logger
    # logger = logging.getLogger("kknight")
    logger.info("Epoch {epoch} now started.".format(epoch=epoch))

    # The model is kept in eval mode during linear-classifier training
    # (e.g. so frozen features / BatchNorm statistics are not updated)
    model.eval()

    # Timers
    end = time.time()
    data_time, batch_time = 0, 0

    # Gradient accumulation interval and statistic display interval
    n_accum_grad = cfg.SOLVER.ACCUM_GRAD
    n_print_intv = n_accum_grad * cfg.SOLVER.DISP_INTERVAL
    max_iter = len(data_loader)

    for iteration, (images, target) in enumerate(data_loader):
        data_time += time.time() - end

        if device is not None:
            images = images.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

        # Compute class logits and the classification loss
        output = model(images)
        loss = criterion(output, target)

        # Measure top-1/top-5 accuracy and record the loss
        acc1, acc5 = contrastive_accuracy(output, target, topk=(1, 5))
        # meters.update(loss=loss, **extra)
        meters.update(loss=loss)
        meters.update(acc1=acc1, acc5=acc5)
        loss.backward()

        # Compute batch time
        batch_time += time.time() - end
        end = time.time()

        if (iteration + 1) % n_accum_grad == 0 or iteration + 1 == max_iter:
            optimizer.step()
            # scheduler.step()
            optimizer.zero_grad()

            # Record batch time and data sampling time
            meters.update(time=batch_time, data=data_time)
            data_time, batch_time = 0, 0

        if (iteration + 1) % n_print_intv == 0 or iteration + 1 == max_iter:
            # Estimated time of arrival of remaining epoch
            total_eta = meters.time.global_avg * max_iter * (cfg.SOLVER.EPOCH -
                                                             epoch)
            eta_seconds = meters.time.global_avg * (max_iter -
                                                    iteration) + total_eta
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            logger.info(
                meters.delimiter.join([
                    "eta: {eta}", "epoch: {epoch}", "iter: {iter}", "{meters}",
                    "lr: {lr:.6f}", "max mem: {memory:.0f}"
                ]).format(
                    eta=eta_string,
                    epoch=epoch,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024. / 1024.,
                ))
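
`contrastive_accuracy` is not defined in these snippets. A plausible top-k accuracy helper in the usual ImageNet style, consistent with the `topk=(1, 5)` call above (this is an assumption, not the repository's actual implementation):

import torch


def topk_accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in %) given logits `output` and integer labels `target`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()                                      # (maxk, B)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum() * 100.0 / batch_size
                for k in topk]
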
Example #23
def train(net: nn.Module,
          train_dataloader: DataLoader = None,
          val_dataloader: DataLoader = None,
          test_dataloader: DataLoader = None,
          is_earlystopping: bool = True) -> nn.Module:
    """
    Training loop iterating on the train dataloader and updating the model's weights.
    Inferring the validation dataloader & test dataloader, if given, to babysit the learning
    Activating cuda device if available.
    :return: Trained model
    """
    train_losses: np.array = np.zeros(NUM_EPOCHS)
    val_losses: np.array = np.zeros(NUM_EPOCHS)
    best_epoch: int = NUM_EPOCHS - 1

    if test_dataloader:
        untrained_test_loss, untrained_y_test_pred = infer(
            net, test_dataloader, loss_fn)
        _, _ = get_num_of_areas_and_targets_from_arary(array=y_test)
        print(f'Test Loss before training: {untrained_test_loss:.3f}')
        _, _, _ = calculate_model_metrics(y_true=y_test,
                                          y_pred=untrained_y_test_pred,
                                          verbose=True)

    for epoch in range(NUM_EPOCHS):
        print(f'*************** Epoch {epoch + 1} ***************')
        net.train()
        h = net.init_hidden(batch_size=BATCH_SIZE)
        loss_running = 0.0  # accumulate per-batch losses to report a mean epoch loss
        for batch_idx, (x_train, y_train) in enumerate(tqdm(train_dataloader)):
            if train_on_gpu:
                net.cuda()
                x_train, y_train = x_train.cuda(), y_train.cuda()
            h = h.data  # detach the hidden state from the previous batch's graph (truncated BPTT)
            optimizer.zero_grad()
            y_train_pred, h = net(x_train, h)
            loss = loss_fn(y_train_pred, y_train)
            loss_running += loss.item()
            loss.backward()
            optimizer.step()

        if val_dataloader:
            val_loss, y_val_pred = infer(net, val_dataloader, loss_fn)
            val_losses[epoch] = val_loss

        if is_earlystopping and check_earlystopping(loss=val_losses,
                                                    epoch=epoch):
            print('EarlyStopping !!!')
            best_epoch = np.argmin(val_losses[:epoch + 1])
            break
        train_losses[epoch] = loss_running / len(train_dataloader)
        scheduler.step(
            val_loss)  # Change the lr if needed based on the validation loss

        if epoch % PRINT_EVERY == 0:
            print(f"Epoch: {epoch + 1}/{NUM_EPOCHS},",
                  f"Train loss: {train_losses[epoch]:.5f},",
                  f"Validation loss: {val_losses[epoch]:.5f}")

            _, _, _ = calculate_model_metrics(y_true=y_train,
                                              y_pred=y_train_pred,
                                              mode='Train-Last Batch')
            if val_dataloader:
                _, _, _ = calculate_model_metrics(y_true=y_val,
                                                  y_pred=y_val_pred,
                                                  mode='Validation')

        if (epoch + 1) % SAVE_EVERY == 0:
            save_pt_model(net=net)

    if best_epoch != NUM_EPOCHS - 1:  # early stopping was triggered; keep losses up to the best epoch
        train_losses = train_losses[:best_epoch + 1]
        val_losses = val_losses[:best_epoch + 1]
    else:
        best_epoch = np.argmin(val_losses)

    print(
        f'Best Epoch: {best_epoch + 1}; Best Validation Loss: {val_losses[best_epoch]:.4f}'
    )
    print(train_losses)
    plot_values_by_epochs(train_values=train_losses,
                          validation_values=val_losses)
    return net
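
`check_earlystopping` is assumed here rather than shown. A minimal patience-based sketch consistent with how it is called above (the full loss array plus the current epoch index); the `patience` value is illustrative:

import numpy as np


def check_earlystopping(loss: np.ndarray, epoch: int, patience: int = 5) -> bool:
    """Stop when the validation loss has not improved for `patience` consecutive epochs."""
    if epoch < patience:
        return False
    best_so_far = int(np.argmin(loss[:epoch + 1]))
    return (epoch - best_so_far) >= patience
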
Example #24
        C_loss_global = torch.sum(
            torch.stack([cross_entropy_loss(output, labels) for output in logits_global_list]), dim=0)

        C_loss = C_loss_local_rest + C_loss_global + C_loss_local + C_loss_rest

        loss = T_loss + 2 * C_loss
        
        losses.update(loss.data.item(), labels.size(0))
        prec1 = (sum([accuracy(output.data, labels.data)[0].item() for output in logits_local_rest_list])
                 + sum([accuracy(output.data, labels.data)[0].item() for output in logits_global_list])
                 + sum([accuracy(output.data, labels.data)[0].item() for output in logits_local_list])
                 + sum([accuracy(output.data, labels.data)[0].item() for output in logits_rest_list]))/(12+12+12+9)
        precisions.update(prec1, labels.size(0))

        loss.backward()

        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % args.steps_per_log == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'Prec {:.2%} ({:.2%})\t'
                  .format(epoch, i + 1, len(train_loader),
                          batch_time.val, args.steps_per_log*batch_time.avg,
                          data_time.val, args.steps_per_log*data_time.avg,
                          losses.val, losses.avg,
                          precisions.val, precisions.avg))
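
The `AverageMeter` used throughout these examples is not shown. A minimal version exposing the `.val`, `.avg`, and `.update(value, n)` interface the loops rely on (a common utility; this particular definition is a sketch, not the original class):

class AverageMeter:
    """Track the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
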
Example #25
    def train(self):
        """ Perform training of the network. """

        num_epochs = 50
        batch_size = 16
        batches_per_epoch = 1024
        learning_rate = 0.02

        optimizer = torch.optim.SGD(self._net.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40, 45],
                                                         gamma=0.1,
                                                         last_epoch=-1)

        training_start_time = time.time()

        self.validate()

        for epoch in range(num_epochs):
            print("Epoch ------ ", epoch)

            train_gen = self._dispatcher.train_gen(batches_per_epoch,
                                                   batch_size)

            self._net.train()

            for batch_index, batch in enumerate(train_gen):
                if self.use_gpu:
                    batch.cuda()

                pred = self._net.forward(batch.image_tensor)

                loss, details = self._net.loss(pred, batch.segmentation_tensor)

                if batch_index % 50 == 0:
                    print("epoch={} batch={} loss={:.4f}".format(
                        epoch, batch_index, loss.item()))
                    self._render_prediction(
                        pred.detach().cpu().numpy()[0],
                        batch.segmentation_tensor.detach().cpu().numpy()[0],
                        batch.image_tensor.detach().cpu().numpy()[0].transpose(
                            (1, 2, 0)))
                    print("-------------------------------")

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            scheduler.step()

            # Save after every epoch
            torch.save(self._net.state_dict(), self._snapshot_name)

            # Validate every epoch
            self.validate()

            # end of epoch

        training_end_time = time.time()
        print("Training took {} hours".format(
            (training_end_time - training_start_time) / 3600))

        print("Train finished!")
Example #26
def train(net: nn.Module, optimizer: torch.optim, train_dataloader: DataLoader = None,
          val_dataloader: DataLoader = None, infer_df: np.array = None, is_earlystopping: bool = False) -> nn.Module:
    """
    Training loop iterating on the train dataloader and updating the model's weights.
    Inferring the validation dataloader & test dataloader, if given, to babysit the learning
    Activating cuda device if available.
    :return: Trained model
    """
    NUMBER_OF_PREDS: int = len(train_dataloader.dataset) * NUM_USERS
    train_losses: np.array = np.zeros(NUM_EPOCHS)
    train_accuracy: np.array = np.zeros(NUM_EPOCHS)
    val_losses: np.array = np.zeros(NUM_EPOCHS)
    val_accuracy: np.array = np.zeros(NUM_EPOCHS)
    train_positive_pred: int = 0
    train_positive_number: int = 0
    best_epoch: int = NUM_EPOCHS - 1

    if val_dataloader:
        untrained_val_loss, untrained_val_accuracy = infer(net=net, infer_dataloader=val_dataloader, loss_fn=loss_fn,
                                                           infer_df=infer_df)
        print(f'Validation Loss before training: {untrained_val_loss:.5f}')

    for epoch in range(NUM_EPOCHS):
        print(f'*************** Epoch {epoch + 1} ***************')
        train_correct_counter = 0
        loss_running = 0

        net.train()
        for x_train, y_train in tqdm(train_dataloader):
            if train_on_gpu:
                net.cuda()
                x_train, y_train = x_train.cuda(), y_train.cuda()
            optimizer.zero_grad()
            y_train_pred = net(x_train)

            loss = loss_fn(y_train_pred.flatten(), y_train.flatten())
            loss_running += loss.item()
            loss.backward()
            optimizer.step()
            # Threshold in torch so this also works when tensors live on the GPU
            train_preds = (y_train_pred > 0.5).long()
            train_correct_counter += (train_preds == y_train).sum()
            train_positive_number += get_number_of_positves(y=y_train)
            train_positive_pred += get_number_of_tp(y_true=y_train, y_pred=train_preds)

        train_losses[epoch] = loss_running / len(train_dataloader)
        train_accuracy[epoch] = train_correct_counter.item() / NUMBER_OF_PREDS
        train_recall = train_positive_pred / train_positive_number * 100

        if val_dataloader:
            val_loss, val_acc = infer(net=net, infer_dataloader=val_dataloader, loss_fn=loss_fn, infer_df=infer_df)
            val_losses[epoch] = val_loss
            val_accuracy[epoch] = val_acc

        if is_earlystopping and val_dataloader and check_earlystopping(loss=val_losses, epoch=epoch):
            print('EarlyStopping !!!')
            best_epoch = np.argmin(val_losses[:epoch + 1])
            break
        if epoch % PRINT_EVERY == 0:
            print(f"Epoch: {epoch + 1}/{NUM_EPOCHS},",
                  f"Train loss: {train_losses[epoch]:.5f}, Train Num Correct: {train_correct_counter} "
                  f"/ {NUMBER_OF_PREDS}, Train Accuracy: {train_accuracy[epoch]:.3f}, Train Recall: {train_recall:.3f}")

            if val_dataloader:
                print(f"Validation loss: {val_losses[epoch]:.5f}, Validation Accuracy: {val_accuracy[epoch]:.3f}")

        if (epoch + 1) % SAVE_EVERY == 0:
            save_pt_model(net=net)

    if best_epoch != NUM_EPOCHS - 1:  # early stopping was triggered; keep losses up to the best epoch
        train_losses = train_losses[:best_epoch + 1]
        val_losses = val_losses[:best_epoch + 1]
    else:
        best_epoch = np.argmin(val_losses)

    print(
        f'Best Epoch: {best_epoch + 1}; Best Validation Loss: {val_losses[best_epoch]:.4f}')
    if val_dataloader:
        print('val_accuracy', val_accuracy)
        print('val_loss', val_loss)
    print(train_losses)
    plot_values_by_epochs(train_values=train_losses, test_values=val_losses)
    return net
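
`get_number_of_positves` and `get_number_of_tp` are assumed helpers rather than shown. Minimal torch-only versions consistent with how they are called above, assuming `y_true` and `y_pred` are binary tensors of matching shape:

import torch


def get_number_of_positves(y: torch.Tensor) -> int:
    """Count positive (label == 1) entries in a batch."""
    return int((y == 1).sum().item())


def get_number_of_tp(y_true: torch.Tensor, y_pred: torch.Tensor) -> int:
    """Count true positives: entries predicted 1 whose label is also 1."""
    return int(((y_pred == 1) & (y_true == 1)).sum().item())
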
Example #27
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None
    ):
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        losses_recons = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
       
        open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        for batch_idx, data in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            imgs_clean = imgs.clone()
            if self.use_gpu:
                imgs = imgs.cuda()
                imgs_clean = imgs_clean.cuda()
                pids = pids.cuda()
            labelss = []
            if epoch >= 0 and epoch < 15:
                randmt = RandomErasing(probability=0.5, sl=0.07, sh=0.3)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img)
                    labelss.append(p)

            if epoch >= 15:
                randmt = RandomErasing(probability=0.5, sl=0.1, sh=0.3)
                for i, img in enumerate(imgs):
                    imgs[i], p = randmt(img)
                    labelss.append(p)

            binary_labels = torch.tensor(np.asarray(labelss)).cuda()
            self.optimizer.zero_grad()

            outputs, outputs2, recons, bin_out1, bin_out2, bin_out3 = self.model(imgs)
            loss_mse = self.criterion_mse(recons, imgs_clean)
            loss = self.mgn_loss(outputs, pids)

            occ_loss1 = self.BCE_criterion(bin_out1.squeeze(1), binary_labels.float())
            occ_loss2 = self.BCE_criterion(bin_out2.squeeze(1), binary_labels.float())
            occ_loss3 = self.BCE_criterion(bin_out3.squeeze(1), binary_labels.float())

            loss = loss + 0.05 * loss_mse + 0.1 * occ_loss1 + 0.1 * occ_loss2 + 0.1 * occ_loss3
            #loss = self.weight_t * loss_t + self.weight_x * loss_x #+ #self.weight_r*loss_mse
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            #losses_t.update(loss_t.item(), pids.size(0))
            losses_x.update(loss.item(), pids.size(0))
            losses_recons.update(occ_loss1.item(), binary_labels.size(0))
            accs.update(metrics.accuracy(outputs, pids)[0].item())

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    #'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                    'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                    'Loss_Occlusion {loss_r.val:.4f} ({loss_r.avg:.4f})\t'             
                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                    'Lr {lr:.6f}\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        #loss_t=losses_t,
                        loss_x=losses_x,
                        loss_r = losses_recons,
                        acc=accs,
                        lr=self.optimizer.param_groups[0]['lr'],
                        eta=eta_str
                    )
                )
            writer = None  # NOTE: this unconditionally disables TensorBoard logging; the block below never runs
            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Data', data_time.avg, n_iter)
                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                writer.add_scalar('Train/Acc', accs.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
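
The `RandomErasing` used above evidently returns both the (possibly) erased image and a flag indicating whether erasing was applied; that flag becomes the occlusion label for the BCE heads. A minimal sketch of such a transform for CHW float tensors (the exact class in the original repository may differ):

import random
import torch


class RandomErasing:
    """Randomly erase a rectangle in the image and report whether erasing happened."""

    def __init__(self, probability=0.5, sl=0.07, sh=0.3, r1=0.3):
        self.probability = probability
        self.sl, self.sh, self.r1 = sl, sh, r1

    def __call__(self, img):
        # img: (C, H, W) float tensor; returns (image, applied_flag)
        if random.random() > self.probability:
            return img, 0
        _, h, w = img.shape
        for _ in range(100):
            area = random.uniform(self.sl, self.sh) * h * w
            ratio = random.uniform(self.r1, 1.0 / self.r1)
            eh = int(round((area * ratio) ** 0.5))
            ew = int(round((area / ratio) ** 0.5))
            if 0 < eh < h and 0 < ew < w:
                y = random.randint(0, h - eh)
                x = random.randint(0, w - ew)
                img[:, y:y + eh, x:x + ew] = torch.rand(
                    img.size(0), eh, ew, device=img.device, dtype=img.dtype)
                return img, 1
        return img, 0
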
Example #28
def train(net: nn.Module,
          optimizer: torch.optim,
          train_dataloader: DataLoader = None,
          val_dataloader: DataLoader = None,
          is_earlystopping: bool = False) -> nn.Module:
    """
    Training loop iterating on the train dataloader and updating the model's weights.
    Inferring the validation dataloader & test dataloader, if given, to babysit the learning
    Activating cuda device if available.
    :return: Trained model
    """
    train_losses: np.array = np.zeros(NUM_EPOCHS)
    train_accuracy: np.array = np.zeros(NUM_EPOCHS)
    val_losses: np.array = np.zeros(NUM_EPOCHS)
    val_accuracy: np.array = np.zeros(NUM_EPOCHS)
    train_auc: np.array = np.zeros(NUM_EPOCHS)
    val_auc: np.array = np.zeros(NUM_EPOCHS)
    best_epoch: int = NUM_EPOCHS - 1

    if val_dataloader:
        untrained_val_loss, untrained_val_accuracy, untrained_val_auc = infer(
            net, val_dataloader, loss_fn)
        print(f'Validation Loss before training: {untrained_val_loss:.5f}')

    for epoch in range(NUM_EPOCHS):
        print(f'*************** Epoch {epoch + 1} ***************')
        train_correct_counter = 0
        train_auc_accumulated = 0
        loss_running = 0
        net.train()
        for x_train, y_train in tqdm(train_dataloader):
            if x_train.shape[-1] == 224:
                y_train = torch.tensor(np.where(y_train == 3, 0, 1)).long()
            if train_on_gpu:
                net.cuda()
                x_train, y_train = x_train.cuda(), y_train.cuda()
            optimizer.zero_grad()
            y_train_pred = net(x_train)

            loss = loss_fn(y_train_pred, y_train)
            loss_running += loss.item()
            loss.backward()
            optimizer.step()
            _, train_preds = torch.max(y_train_pred, dim=1)
            train_correct_counter += torch.sum(train_preds == y_train)
            train_auc_accumulated += calculate_auc_score(y_true=y_train,
                                                         y_pred=train_preds)

        train_losses[epoch] = loss_running / len(train_dataloader)
        train_accuracy[epoch] = train_correct_counter.item() / len(
            train_dataloader.dataset)
        train_auc[epoch] = train_auc_accumulated / len(train_dataloader)

        if val_dataloader:
            val_loss, val_acc, val_auc_val = infer(net, val_dataloader,
                                                   loss_fn)
            val_losses[epoch] = val_loss
            val_accuracy[epoch] = val_acc
            val_auc[epoch] = val_auc_val

        if is_earlystopping and check_earlystopping(loss=val_losses,
                                                    epoch=epoch):
            print('EarlyStopping !!!')
            best_epoch = np.argmin(val_losses[:epoch + 1])
            break

        if epoch % PRINT_EVERY == 0:
            print(
                f"Epoch: {epoch + 1}/{NUM_EPOCHS},",
                f"Train loss: {train_losses[epoch]:.5f}, Train Num Correct: {train_correct_counter} "
                f"/ {len(train_dataloader.dataset)}, Train Accuracy: {train_accuracy[epoch]:.3f}\n",
                f"Validation loss: {val_losses[epoch]:.5f}, Validation Accuracy: {val_accuracy[epoch]:.3f}",
                f"Validation AUC: {val_auc[epoch]:.5f}, Train AUC: {train_auc[epoch]:.5f}"
            )

        if (epoch + 1) % SAVE_EVERY == 0:
            save_pt_model(net=net)

    if best_epoch != NUM_EPOCHS - 1:  # early stopping was triggered; keep losses up to the best epoch
        train_losses = train_losses[:best_epoch + 1]
        val_losses = val_losses[:best_epoch + 1]
    else:
        best_epoch = np.argmin(val_losses)

    print(
        f'Best Epoch: {best_epoch + 1}; Best Validation Loss: {val_losses[best_epoch]:.4f}'
    )
    if val_dataloader:
        print('val_accuracy', val_accuracy)
        print('val_loss', val_loss)
    print(train_losses)
    plot_values_by_epochs(train_values=train_losses, test_values=val_losses)
    return net
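
`calculate_auc_score` is not shown in this snippet. A plausible per-batch helper built on scikit-learn, guarding against batches that contain only one class (an assumption about the original helper, not its actual code):

import numpy as np
import torch
from sklearn.metrics import roc_auc_score


def calculate_auc_score(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """ROC-AUC for a single batch; falls back to 0.5 when only one class is present."""
    y_true = y_true.detach().cpu().numpy()
    y_pred = y_pred.detach().cpu().numpy()
    if len(np.unique(y_true)) < 2:
        return 0.5
    return float(roc_auc_score(y_true, y_pred))
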