Example #1
0
    def extract_feature(self, data_loader):
        """Extract pooled and raw features for every sample in data_loader.

        A short final batch is padded with images saved from the first batch
        so the models always see a full batch; the padded rows are cropped
        off before concatenation.

        Returns:
            (allfeatures, allfeatures_raw): pooled and raw feature tensors,
            one row per sample, in data_loader order.
        """
        print_freq = 50
        self.cnn_model.eval()
        self.att_model.eval()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()

        allfeatures = 0
        allfeatures_raw = 0

        # Device never changes; pick it once instead of per batch.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        for i, (imgs, flows, _, _) in enumerate(data_loader):
            # FIX: data_time was printed but never updated (always showed 0).
            data_time.update(time.time() - end)
            imgs = to_torch(imgs).to(device)
            flows = to_torch(flows).to(device)
            with torch.no_grad():
                if i == 0:
                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
                    out_feat, out_raw = self.att_model.selfpooling_model(out_feat, out_raw)
                    allfeatures = out_feat
                    allfeatures_raw = out_raw
                    # Keep the first (full) batch around to pad a short last batch.
                    preimgs = imgs
                    preflows = flows
                elif imgs.size(0) < data_loader.batch_size:
                    flaw_batchsize = imgs.size(0)
                    cat_batchsize = data_loader.batch_size - flaw_batchsize
                    imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                    flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)

                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
                    out_feat, out_raw = self.att_model.selfpooling_model(out_feat, out_raw)

                    out_feat = out_feat[0:flaw_batchsize]
                    # BUG FIX: raw features were cropped from out_feat, which
                    # duplicated the pooled features; crop out_raw instead.
                    out_raw = out_raw[0:flaw_batchsize]

                    allfeatures = torch.cat((allfeatures, out_feat), 0)
                    allfeatures_raw = torch.cat((allfeatures_raw, out_raw), 0)
                else:
                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
                    out_feat, out_raw = self.att_model.selfpooling_model(out_feat, out_raw)

                    allfeatures = torch.cat((allfeatures, out_feat), 0)
                    allfeatures_raw = torch.cat((allfeatures_raw, out_raw), 0)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      .format(i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg))

        return allfeatures, allfeatures_raw
class Train_classifier:
    """Train a classifier head on top of a frozen cycle-consistent VAE encoder.

    ``opt.model_type`` selects the classifier input: "specified" feeds the
    style latent, anything else feeds the class latent.
    """

    def __init__(self, opt):
        """Build models, losses, optimizer and the checkpoint directory.

        Args:
            opt: config object exposing hyper-parameters (seed, batch_size,
                save_dir, model_type, debug, ...) and ``_state_dict()``.
        """
        self.opt = opt
        torch.manual_seed(opt.seed)

        print('=========user config==========')
        pprint(opt._state_dict())
        print('============end===============')

        self.trainloader, self.valloader = load_dataloaders(self.opt)

        self.use_gpu = opt.use_gpu
        self.device = torch.device('cuda')

        self._init_model()
        self._init_criterion()
        self._init_optimizer()

        # Checkpoints and figures go to <save_dir>/<model_type>_<model_name>.
        self.model_dir = os.path.join(
            self.opt.save_dir,
            str(self.opt.model_type) + "_" + str(self.opt.model_name))

        Path(self.model_dir).mkdir(parents=True, exist_ok=True)

        recon_dir = os.path.join(self.model_dir, 'reconstructed_images')
        if not os.path.exists(recon_dir):
            os.makedirs(recon_dir)

        # Only talk to wandb outside of debug runs.
        if not self.opt.debug:
            self.experiment = wandb.init(project="cycle_consistent_vae")
            hyper_params = self.opt._state_dict()
            self.experiment.config.update(hyper_params)
            wandb.watch(self.encoder)
            wandb.watch(self.classifier_model)

    def _init_model(self):
        """Load encoder + classifier and allocate reusable input buffers."""
        self.encoder, self.classifier_model = load_model(self.opt)

        print("LEARNING RATE: ", self.opt.base_learning_rate)

        # Pre-allocated buffers that each data batch is copied into.
        self.X_1 = torch.FloatTensor(self.opt.batch_size,
                                     self.opt.num_channels,
                                     self.opt.image_size, self.opt.image_size)
        self.X_2 = torch.FloatTensor(self.opt.batch_size,
                                     self.opt.num_channels,
                                     self.opt.image_size, self.opt.image_size)
        self.X_3 = torch.FloatTensor(self.opt.batch_size,
                                     self.opt.num_channels,
                                     self.opt.image_size, self.opt.image_size)

        self.style_latent_space = torch.FloatTensor(self.opt.batch_size,
                                                    self.opt.style_dim)

        if self.use_gpu:
            self.device = torch.device('cuda')

            self.encoder.cuda()
            self.classifier_model.cuda()

            self.X_1 = self.X_1.cuda()
            self.X_2 = self.X_2.cuda()
            self.X_3 = self.X_3.cuda()

            self.style_latent_space = self.style_latent_space.cuda()

        self.load_encoder()

    def _init_optimizer(self):
        """Adam optimizer plus a StepLR schedule (lr divided by 10 every 80 epochs)."""
        self.optimizer = optim.Adam(self.classifier_model.parameters(),
                                    lr=self.opt.base_learning_rate,
                                    betas=(self.opt.beta1, self.opt.beta2))
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=80,
                                                   gamma=0.1)

    def _init_criterion(self):
        """Cross-entropy classification loss (moved to GPU)."""
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.cross_entropy_loss.cuda()

    def train(self):
        """Full training loop; checkpoints the best model by validation accuracy."""
        sys.stdout = Logger(osp.join(self.model_dir, 'log_train.txt'))

        print("TRAINING CLASSIFIER...")

        start_epoch = self.opt.start_epoch
        best_acc = 0.0
        best_epoch = 0

        for epoch in range(start_epoch, self.opt.total_epochs):

            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            self.train_one_epoch(epoch)
            val_acc = self.evaluation(epoch)

            self.scheduler.step()

            if val_acc > best_acc:
                best_acc = val_acc
                # BUG FIX: best_epoch was initialised but never updated.
                best_epoch = epoch
                self.save_model(best_model=True)

        print('Best validation accuracy {} at epoch {}'.format(
            best_acc, best_epoch))

        self.save_model(best_model=False)
        if not self.opt.debug:
            print("UPLOADING FINAL FILES ...")
            wandb.save(self.model_dir + "/*")

    def train_one_epoch(self, epoch):
        """Run one optimization pass over the training loader."""
        self.classifier_model.train()

        self.cross_entropy_losses = AverageMeter()
        self.accuracy = AverageMeter()
        correct = 0
        total = 0

        for batch_idx, data in enumerate(self.trainloader):

            image_batch_1, image_batch_2, labels = data
            labels = labels.cuda()
            self.optimizer.zero_grad()
            self.X_1.copy_(image_batch_1)
            # Encoder is frozen; it only provides the latents.
            self.style_mu_1, self.style_logvar_1, self.class_latent_space_1 = self.encoder(
                Variable(self.X_1))
            # training=False -> deterministic (no sampling noise).
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=self.style_mu_1,
                                                  logvar=self.style_logvar_1)

            if self.opt.model_type == "specified":
                outputs = self.classifier_model(style_latent_space_1)
            else:
                outputs = self.classifier_model(self.class_latent_space_1)

            self.loss = self.cross_entropy_loss(outputs, labels)

            self.loss.backward()
            self.optimizer.step()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            self.cross_entropy_losses.update(self.loss.item())

            # Free batch tensors eagerly to keep GPU memory flat.
            del (image_batch_1)
            del (image_batch_2)
            del (outputs)
            del (labels)

        self.accuracy.update(correct / total)
        self._print_values(epoch)

    def _print_values(self, epoch):
        """Log training loss/accuracy to wandb (non-debug) and stdout."""
        if not self.opt.debug:
            self.experiment.log(
                {'Cross Entropy loss': self.cross_entropy_losses.mean},
                step=epoch)
            self.experiment.log({'Train Accuracy': self.accuracy.mean},
                                step=epoch)

        print('Cross Entropy loss: ' + str(self.cross_entropy_losses.mean))
        print('Train Accuracy: ' + str(self.accuracy.mean))

    def evaluation(self, epoch):
        """Evaluate on the validation loader; returns mean accuracy."""
        print("Evaluating Model ...")

        self.classifier_model.eval()

        self.val_accuracy = AverageMeter()
        correct = 0
        total = 0
        with torch.no_grad():

            for batch_idx, data in enumerate(self.valloader):

                image_batch_1, image_batch_2, labels = data

                self.X_1.copy_(image_batch_1)

                labels = labels.cuda()

                self.style_mu_1, self.style_logvar_1, self.class_latent_space_1 = self.encoder(
                    Variable(self.X_1))
                style_latent_space_1 = reparameterize(
                    training=False,
                    mu=self.style_mu_1,
                    logvar=self.style_logvar_1)

                if self.opt.model_type == "specified":
                    outputs = self.classifier_model(style_latent_space_1)
                else:
                    outputs = self.classifier_model(self.class_latent_space_1)

                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        self.val_accuracy.update(correct / total)

        if not self.opt.debug:
            self.experiment.log(
                {'Validation Accuracy': self.val_accuracy.mean}, step=epoch)

        print('Validation Accuracy: ' + str(self.val_accuracy.mean))
        print("Correct: ", correct)
        print("Total: ", total)
        return self.val_accuracy.mean

    def visualization(self):
        """Plot one validation batch, framing correct (green) / wrong (red) predictions."""
        self.load_classifier(True)

        self.classifier_model.eval()
        val_iter = iter(self.valloader)
        # BUG FIX: `val_iter.next()` is Python 2 only; the iterator protocol
        # (and modern DataLoader iterators) require the next() builtin.
        data = next(val_iter)

        with torch.no_grad():
            image_batch_1, image_batch_2, labels = data

            self.X_1.copy_(image_batch_1)

            labels = labels.cuda()

            self.style_mu_1, self.style_logvar_1, self.class_latent_space_1 = self.encoder(
                Variable(self.X_1))
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=self.style_mu_1,
                                                  logvar=self.style_logvar_1)

            if self.opt.model_type == "specified":
                outputs = self.classifier_model(style_latent_space_1)
            else:
                outputs = self.classifier_model(self.class_latent_space_1)

            _, predicted = outputs.max(1)
            # NCHW -> NHWC for plotting.
            image_batch = (np.transpose(self.X_1.cpu().numpy(), (0, 2, 3, 1)))

            labels = labels.detach().cpu().numpy()
            predicted = predicted.detach().cpu().numpy()

        shape = [2, 8]
        fig = plt.figure(1)
        grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
        size = shape[0] * shape[1]

        for i in range(size):
            current_image = image_batch[i]

            current_image = current_image * 255
            current_image = current_image.astype("uint8")

            # UMat round-trip yields a contiguous array cv2 can draw on.
            current_image = cv2.UMat(current_image).get()

            if labels[i] == predicted[i]:
                # Green border: correct prediction.
                cv2.rectangle(current_image, (0, 0), (60, 60), (0, 255, 0), 2)
            else:
                # Red border: misclassified.
                cv2.rectangle(current_image, (0, 0), (60, 60), (255, 0, 0),
                              2)

            grid[i].axis('off')
            grid[i].imshow(
                current_image)  # The AxesGrid object work as a list of axes.

        print("SAVING IMAGES")
        path = os.path.join(self.model_dir, 'missclassification.png')
        plt.savefig(path)
        plt.clf()

    def save_model(self, best_model):
        """Save classifier weights; best checkpoints get a separate file."""
        print("SAVING MODEL ...")

        if best_model:
            torch.save(self.classifier_model.state_dict(),
                       os.path.join(self.model_dir, "best_classifier.pth"))
        else:
            torch.save(self.classifier_model.state_dict(),
                       os.path.join(self.model_dir, "classifier.pth"))

    def load_encoder(self):
        """Load pre-trained VAE encoder weights from <save_dir>/vae/encoder.pth."""
        print("[*] LOADING ENCODER: {}".format(
            os.path.join(self.opt.save_dir, "vae", "encoder.pth")))
        self.encoder.load_state_dict(
            torch.load(os.path.join(self.opt.save_dir, "vae", "encoder.pth")))

        self.encoder.cuda()

    def load_decoder(self):
        """Load pre-trained VAE decoder weights from <save_dir>/vae/decoder.pth."""
        print("[*] LOADING DECODER: {}".format(
            os.path.join(self.opt.save_dir, "vae", "decoder.pth")))
        self.decoder.load_state_dict(
            torch.load(os.path.join(self.opt.save_dir, "vae", "decoder.pth")))
        self.decoder.cuda()

    def load_classifier(self, best_model):
        """Load classifier weights saved by save_model()."""
        if best_model:
            print("LOADING BEST MODEL")

            self.classifier_model.load_state_dict(
                torch.load(os.path.join(self.model_dir,
                                        "best_classifier.pth")))
        else:
            # BUG FIX: save_model() writes "classifier.pth"; the old
            # "_classifier.pth" path could never exist.
            self.classifier_model.load_state_dict(
                torch.load(os.path.join(self.model_dir, "classifier.pth")))

        self.classifier_model.cuda()
Example #3
0
    def train(self,
              epoch,
              data_loader,
              optimizer1,
              optimizer2,
              optimizer3,
              print_freq=1):
        """Run one training epoch, stepping all three optimizers per batch."""
        # All three sub-networks train jointly.
        for model in (self.img_model, self.diff_model, self.depth_model):
            model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()
        optimizers = (optimizer1, optimizer2, optimizer3)

        tic = time.time()
        for batch_idx, raw_batch in enumerate(data_loader):
            data_time.update(time.time() - tic)

            inputs, targets = self._parse_data(raw_batch)
            loss, prec1 = self._forward(inputs, targets)

            n = targets.size(0)
            losses.update(loss.item(), n)
            precisions.update(prec1, n)

            # Clear every optimizer, backprop once, then step every optimizer.
            for opt in optimizers:
                opt.zero_grad()
            loss.backward()
            for opt in optimizers:
                opt.step()

            batch_time.update(time.time() - tic)
            tic = time.time()

            if (batch_idx + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'.format(
                          epoch, batch_idx + 1, len(data_loader),
                          batch_time.val, batch_time.avg,
                          data_time.val, data_time.avg,
                          losses.val, losses.avg,
                          precisions.val, precisions.avg))
Example #4
0
    def train(self,damage_initial_previous_frame_mask=True,lossfunc='cross_entropy',model_resume=False):
        """Train the VOS model on DAVIS-2017 + YouTube-VOS, alternating datasets.

        Args:
            damage_initial_previous_frame_mask: if True, randomly corrupt the
                previous-frame masks (label1s) to simulate imperfect guidance.
            lossfunc: 'bce' or 'cross_entropy'; selects the training criterion.
            model_resume: resume from 'save_step_60000.pth' in save_res_dir.

        Raises:
            ValueError: if ``lossfunc`` is not a supported name.
        """
        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM,weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        # Augmentation pipelines; YouTube-VOS uses fixed scales and larger crops.
        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                    tr.RandomScale(),
                                                     tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP)),
                                                     tr.Resize(cfg.DATA_RESCALE),
                                                     tr.ToTensor()])
        composed_transforms_ytb = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                    tr.RandomScale([0.5,1,1.25]),
                                                     tr.RandomCrop((800,800)),
                                                     tr.Resize(cfg.DATA_RESCALE),
                                                     tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT,transform=composed_transforms_ytb)
        trainloader_davis = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        trainloader_ytb = DataLoader(ytb_train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        trainloader=[trainloader_ytb,trainloader_davis]
        print('dataset processing finished.')

        if lossfunc=='bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc=='cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        else:
            # BUG FIX: previously this only printed a warning and the loop later
            # crashed with a NameError on `criterion`; fail fast instead.
            raise ValueError('unsupported loss function. Please choose from [cross_entropy,bce]')

        max_itr = cfg.TRAIN_TOTAL_STEPS
        step=0

        if model_resume:
            saved_model_=os.path.join(self.save_res_dir,'save_step_60000.pth')
            saved_model_ = torch.load(saved_model_)
            self.model=self.load_network(self.model,saved_model_)
            step=60000
            print('resume from step {}'.format(step))

        while step<cfg.TRAIN_TOTAL_STEPS:
            # One full pass over each dataset per outer iteration.
            for train_dataloader in trainloader:
                for ii, sample in enumerate(train_dataloader):
                    now_lr=self._adjust_lr(optimizer,step,max_itr)
                    ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                    img1s = sample['img1']
                    img2s = sample['img2']
                    ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                    label1s = sample['label1']
                    label2s = sample['label2']
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    bs,_,h,w = img2s.size()
                    inputs = torch.cat((ref_imgs,img1s,img2s),0)
                    if damage_initial_previous_frame_mask:
                        try:
                            label1s = damage_masks(label1s)
                        except Exception:
                            # Best-effort: fall back to the clean mask.
                            # (was a bare except:, which also swallowed
                            # KeyboardInterrupt/SystemExit)
                            print('damage_error')

                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels=ref_scribble_labels.cuda()
                        label1s = label1s.cuda()
                        label2s = label2s.cuda()

                    tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS)

                    # Upsample each sequence's logits to frame size and build
                    # the matching per-sequence label dict for the criterion.
                    label_and_obj_dic={}
                    label_dic={}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_]=(label2s[i],obj_nums[i])
                    for seq_ in tmp_dic.keys():
                        tmp_pred_logits = tmp_dic[seq_]
                        tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits,size=(h,w),mode = 'bilinear',align_corners=True)
                        tmp_dic[seq_]=tmp_pred_logits

                        label_tmp,obj_num = label_and_obj_dic[seq_]
                        obj_ids = np.arange(1,obj_num+1)
                        obj_ids = torch.from_numpy(obj_ids)
                        obj_ids = obj_ids.int()
                        if torch.cuda.is_available():
                            obj_ids = obj_ids.cuda()
                        if lossfunc == 'bce':
                            # One binary channel per object id for BCE.
                            label_tmp = label_tmp.permute(1,2,0)
                            label = (label_tmp.float()==obj_ids.float())
                            label = label.unsqueeze(-1).permute(3,2,0,1)
                            label_dic[seq_]=label.float()
                        elif lossfunc =='cross_entropy':
                            label_dic[seq_]=label_tmp.long()

                    loss = criterion(tmp_dic,label_dic,step)
                    loss =loss/bs

                    # Debug escape hatch: dump predictions/labels and abort on
                    # a pathological loss spike.
                    if loss.item()>10000:
                        print(tmp_dic)
                        for k,v in tmp_dic.items():
                            v = v.cpu()
                            v = v.detach().numpy()
                            np.save(k+'.npy',v)
                            l=label_dic[k]
                            l=l.cpu().detach().numpy()
                            np.save('lab'+k+'.npy',l)
                        exit()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss.update(loss.item(),bs)
                    if step%1==0:
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step,now_lr ,running_loss.val,running_loss.avg))

                        # Undo ImageNet normalization for visualization.
                        show_ref_img = ref_imgs.cpu().numpy()[0]
                        show_img1 = img1s.cpu().numpy()[0]
                        show_img2 = img2s.cpu().numpy()[0]

                        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

                        show_ref_img = show_ref_img*sigma+mean
                        show_img1 = show_img1*sigma+mean
                        show_img2 = show_img2*sigma+mean

                        show_gt = label2s.cpu()[0]
                        show_gt = show_gt.squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2,0,1))

                        show_preds = tmp_dic[seq_names[0]].cpu()
                        show_preds=nn.functional.interpolate(show_preds,size=(h,w),mode = 'bilinear',align_corners=True)
                        show_preds = show_preds.squeeze(0)
                        if lossfunc=='bce':
                            # Threshold each object channel, then collapse to an id map.
                            show_preds = (torch.sigmoid(show_preds)>0.5)
                            show_preds_s = torch.zeros((h,w))
                            for i in range(show_preds.size(0)):
                                show_preds_s[show_preds[i]]=i+1
                        elif lossfunc=='cross_entropy':
                            show_preds_s = torch.argmax(show_preds,dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2,0,1))

                        pix_acc = np.sum(show_preds_s==show_gt)/(h*w)

                        tblogger.add_scalar('loss', running_loss.avg, step)
                        tblogger.add_scalar('pix_acc', pix_acc, step)
                        tblogger.add_scalar('now_lr', now_lr, step)
                        tblogger.add_image('Reference image', show_ref_img, step)
                        tblogger.add_image('Previous frame image', show_img1, step)
                        tblogger.add_image('Current frame image', show_img2, step)
                        tblogger.add_image('Groud Truth', show_gtf, step)
                        tblogger.add_image('Predict label', show_preds_sf, step)

                    # Periodic checkpoint (skip step 0).
                    if step%5000==0 and step!=0:
                        self.save_network(self.model,step)

                    step+=1
Example #5
0
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one epoch of training or evaluation over ``data_loader``.

    Args:
        data_loader: iterable yielding ``(inputs, target)`` batches.
        model: network to run; wrapped in ``DataParallel`` when more than one
            GPU is configured in ``args.gpus``.
        criterion: loss function applied to ``(output, target)``.
        epoch: epoch index, used for the optimizer schedule and logging.
        training: when True, back-propagate and step ``optimizer``.
        optimizer: required when ``training`` is True; also exposes a custom
            ``update(epoch, global_step)`` schedule hook.

    Returns:
        Tuple ``(avg_loss, avg_top1, avg_top5)`` over the epoch.
    """
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    # Gradients are only needed while training; this replaces the removed
    # Variable(volatile=...) API from pre-0.4 PyTorch.
    with torch.set_grad_enabled(training):
        for i, (inputs, target) in enumerate(data_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if args.gpus is not None:
                # 'async' became a reserved word in Python 3.7; the kwarg is
                # now called non_blocking.
                target = target.cuda(non_blocking=True)
            input_var = inputs.type(args.type)
            target_var = target

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)
            if type(output) is list:
                output = output[0]

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            # loss.data[0] was removed in PyTorch 0.4; item() is the scalar value.
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

            if training:
                optimizer.update(epoch, epoch * len(data_loader) + i)
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                             'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                             'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                             'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                             'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                                 epoch, i, len(data_loader),
                                 phase='TRAINING' if training else 'EVALUATING',
                                 batch_time=batch_time,
                                 data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Example #6
0
    def train(self,damage_initial_previous_frame_mask=True,lossfunc='cross_entropy',model_resume=False,eval_total=False,init_prev=False):
        """Round-based training loop for the interaction segmentation head.

        Runs `round_` interaction rounds per outer `while` iteration.  Round 0
        trains from scribbles only; later rounds additionally condition on the
        previous round's predicted label.  Between rounds the dataset's
        reference frames / labels are regenerated, either from true model
        predictions or from damaged ground truth, depending on
        cfg.TRAIN_INTER_USE_TRUE_RESULT.

        Args:
            damage_initial_previous_frame_mask: apparently unused — the
                damage_masks block that would use it is commented out.
                TODO(review): confirm intended behavior.
            lossfunc: 'cross_entropy' or 'bce'; selects the criterion and the
                label/prediction post-processing.
            model_resume: if True, load the step-75000 checkpoint and resume.
            eval_total, init_prev: unused in this method — TODO(review) confirm.
        """
        ###################
        interactor = interactive_robot.InteractiveScribblesRobot()
        self.model.train()
        running_loss = AverageMeter()
        # Only the interaction head's parameters are optimized; the backbone
        # is used frozen (see the no_grad feature extraction below).
        optimizer = optim.SGD(self.model.inter_seghead.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM,weight_decay=cfg.TRAIN_WEIGHT_DECAY)
#        optimizer = optim.SGD(self.model.inter_seghead.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM)
        #scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=cfg.TRAIN_LR_STEPSIZE,gamma=cfg.TRAIN_LR_GAMMA)


        ###################

        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                    tr.RandomScale(),
                                                     tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP),10),
                                                     tr.Resize(cfg.DATA_RESCALE),
                                                     tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        train_list = train_dataset.seqs

#        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)

#        trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
#                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        print('dataset processing finished.')
        if lossfunc=='bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc=='cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        else:
            print('unsupported loss funciton. Please choose from [cross_entropy,bce]')

        max_itr = cfg.TRAIN_TOTAL_STEPS

        step=0
        round_=3
        epoch_per_round=30
        if model_resume:
            saved_model_=os.path.join(self.save_res_dir,'save_step_75000.pth')

            saved_model_ = torch.load(saved_model_)
            self.model=self.load_network(self.model,saved_model_)
            step=75000
            print('resume from step {}'.format(step))
        # Outer loop: repeat interaction rounds until the step budget is spent.
        while step<cfg.TRAIN_TOTAL_STEPS:

            # Hard cap in addition to cfg.TRAIN_TOTAL_STEPS.
            if step>100001:
                break

            for r in range(round_):
                if r==0:
                    # Round 0: reset the global map, restore the augmenting
                    # transform and re-sample reference-frame scribbles.
                    print('start new')
                    global_map_tmp_dic={}
                    train_dataset.transform=transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                    tr.RandomScale(),
                                                     tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP)),
                                                     tr.Resize(cfg.DATA_RESCALE),
                                                     tr.ToTensor()])
                    train_dataset.init_ref_frame_dic()

                trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        sampler = RandomIdentitySampler(train_dataset.sample_list), 
                        shuffle=False,num_workers=cfg.NUM_WORKER,pin_memory=True)
                print('round:{} start'.format(r))
                for epoch in range(epoch_per_round):



                    for ii, sample in enumerate(trainloader):
                        now_lr=self._adjust_lr(optimizer,step,max_itr)
                        ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                        #img1s = sample['img1'] 
                        #img2s = sample['img2']
                        ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                        #label1s = sample['label1']
                        #label2s = sample['label2']
                        seq_names = sample['meta']['seq_name'] 
                        obj_nums = sample['meta']['obj_num']
                        #frame_nums = sample['meta']['frame_num']
                        ref_frame_nums = sample['meta']['ref_frame_num']
                        ref_frame_gts=sample['ref_frame_gt']
                        bs,_,h,w = ref_imgs.size()
#                        print(ref_imgs.size())

#                        if r==0:
#                            ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                        ##########
                        if self.use_gpu:
                            inputs = ref_imgs.cuda()

                            ref_scribble_labels=ref_scribble_labels.cuda()
                            ref_frame_gts = ref_frame_gts.cuda()
                            #label1s = label1s.cuda()
                            #label2s = label2s.cuda()
                        #print(inputs.size()) 
                        ##########        
                        # Backbone features are extracted without gradients:
                        # only int_seghead is trained (see the optimizer above).
                        with torch.no_grad():
                            self.model.feature_extracter.eval()
                            self.model.semantic_embedding.eval()
                            ref_frame_embedding = self.model.extract_feature(inputs)
                        if r==0:
                            # First round: no previous-round label exists yet.
                            first_inter=True

                            # NOTE(review): a fresh {} is passed as
                            # global_map_tmp_dic instead of the dict created at
                            # round 0 above — confirm this is intended.
                            tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,ref_scribble_label=ref_scribble_labels,
                                    prev_round_label=None,normalize_nearest_neighbor_distances=True,global_map_tmp_dic={},
                                 seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,
                                 frame_num=ref_frame_nums,first_inter=first_inter)
                        else:
                            # Later rounds: condition on the previous round's
                            # predicted label as well as the scribbles.
                            first_inter=False
                            prev_round_label=sample['prev_round_label']
                        #    print(prev_round_label.size())
                            #prev_round_label=prev_round_label_dic[seq_names[0]]
                            prev_round_label=prev_round_label.cuda()
                            tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,ref_scribble_label=ref_scribble_labels,
                                    prev_round_label=prev_round_label,normalize_nearest_neighbor_distances=True,global_map_tmp_dic={},
                                 seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,
                                 frame_num=ref_frame_nums,first_inter=first_inter)




                        # Build per-sequence target labels, upsampling each
                        # prediction back to the input resolution first.
                        label_and_obj_dic={}
                        label_dic={}
                        for i, seq_ in enumerate(seq_names):
                            label_and_obj_dic[seq_]=(ref_frame_gts[i],obj_nums[i])
                        for seq_ in tmp_dic.keys():
                            tmp_pred_logits = tmp_dic[seq_]
                            tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits,size=(h,w),mode = 'bilinear',align_corners=True)
                            tmp_dic[seq_]=tmp_pred_logits

                            label_tmp,obj_num = label_and_obj_dic[seq_]
                            obj_ids = np.arange(0,obj_num+1)
                            obj_ids = torch.from_numpy(obj_ids)
                            obj_ids = obj_ids.int()
                            if torch.cuda.is_available():
                                obj_ids = obj_ids.cuda()
                            if lossfunc == 'bce':
                                # One binary channel per object id (incl. background).
                                label_tmp = label_tmp.permute(1,2,0)
                                label = (label_tmp.float()==obj_ids.float())
                                label = label.unsqueeze(-1).permute(3,2,0,1)
                                label_dic[seq_]=label.float()
                            elif lossfunc =='cross_entropy':
                                label_dic[seq_]=label_tmp.long()


                        # Criterion consumes the per-sequence dicts; loss is
                        # normalized by batch size.
                        loss = criterion(tmp_dic,label_dic,step)
                        loss =loss/bs
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        #scheduler.step()


                        running_loss.update(loss.item(),bs)
                        # Periodic console logging + tensorboard visualization.
                        if step%50==0:
                            #print(torch.cuda.memory_allocated())
                            #print(torch.cuda.max_memory_cached())
                            torch.cuda.empty_cache()
                            #torch.cuda.reset_max_memory_allocated()
                            print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step,now_lr ,running_loss.val,running_loss.avg))
                        #    print('step:{}'.format(step))

                            show_ref_img = ref_imgs.cpu().numpy()[0]
                            #show_img1 = img1s.cpu().numpy()[0]
                            #show_img2 = img2s.cpu().numpy()[0]        

                            # Undo ImageNet normalization for display.
                            mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                            sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

                            show_ref_img = show_ref_img*sigma+mean
                            #show_img1 = show_img1*sigma+mean
                            #show_img2 = show_img2*sigma+mean        


                            #show_gt = label2s.cpu()[0]        

                            show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                            show_gtf = label2colormap(show_gt).transpose((2,0,1))
                            show_scrbble=ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                            show_scrbble=label2colormap(show_scrbble).transpose((2,0,1))
                            if r!=0:
                                show_prev_round_label=prev_round_label.cpu()[0].squeeze(0).numpy()
                                show_prev_round_label=label2colormap(show_prev_round_label).transpose((2,0,1))
                            else:
                                show_prev_round_label = np.zeros_like(show_gt)

                                show_prev_round_label = label2colormap(show_prev_round_label).transpose((2,0,1))

                            ##########
                            show_preds = tmp_dic[seq_names[0]].cpu()
                            show_preds=nn.functional.interpolate(show_preds,size=(h,w),mode = 'bilinear',align_corners=True)
                            show_preds = show_preds.squeeze(0)
                            if lossfunc=='bce':
                                # Drop background channel, then resolve per-pixel
                                # multi-object overlaps with argmax.
                                show_preds = show_preds[1:]

                                show_preds = (torch.sigmoid(show_preds)>0.5)
                                marker = torch.argmax(show_preds,dim=0)
                                show_preds_s = torch.zeros((h,w))
                                for i in range(show_preds.size(0)):
                                    tmp_mask = (marker==i) & (show_preds[i]>0.5)
                                    show_preds_s[tmp_mask]=i+1
                            elif lossfunc=='cross_entropy':
                                show_preds_s = torch.argmax(show_preds,dim=0)
                            show_preds_s = show_preds_s.numpy()
                            show_preds_sf = label2colormap(show_preds_s).transpose((2,0,1))

                            pix_acc = np.sum(show_preds_s==show_gt)/(h*w)



                            if cfg.TRAIN_TBLOG:
                                tblogger.add_scalar('loss', running_loss.avg, step)
                                tblogger.add_scalar('pix_acc', pix_acc, step)
                                tblogger.add_scalar('now_lr', now_lr, step)
                                tblogger.add_image('Reference image', show_ref_img, step)
                                #tblogger.add_image('Previous frame image', show_img1, step)

                           #     tblogger.add_image('Current frame image', show_img2, step)
                                tblogger.add_image('Groud Truth', show_gtf, step)
                                tblogger.add_image('Predict label', show_preds_sf, step)

                                tblogger.add_image('Scribble', show_scrbble, step)
                                tblogger.add_image('prev_round_label', show_prev_round_label, step)

                            ###########TODO
                        # Periodic checkpointing.
                        if step%20000==0 and step!=0:
                            self.save_network(self.model,step)




                        step+=1
                print('trainset evaluating...')
                print('*'*100)


                # Between rounds, regenerate reference scribbles / previous-round
                # labels: either from real model predictions (branch below) or
                # from damaged ground truth (else branch).
                if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                    if r!=round_-1:
                        if r ==0:
                            prev_round_label_dic={}
                        self.model.eval()
                        with torch.no_grad():
                            round_scribble={}

                            frame_num_dic= {}
                            # No augmentation while generating next-round inputs.
                            train_dataset.transform=transforms.Compose([tr.Resize(cfg.DATA_RESCALE),tr.ToTensor()])
#                            train_dataset.transform=composed_transforms
                            trainloader = DataLoader(train_dataset,batch_size=1,
                                sampler = RandomIdentitySampler(train_dataset.sample_list), 
                                shuffle=False,num_workers=cfg.NUM_WORKER,pin_memory=True)
                            for ii, sample in enumerate(trainloader):
                                ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                                img1s = sample['img1'] 
                                img2s = sample['img2']
                                ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                                label1s = sample['label1']
                                label2s = sample['label2']
                                seq_names = sample['meta']['seq_name'] 
                                obj_nums = sample['meta']['obj_num']
                                frame_nums = sample['meta']['frame_num']
                                bs,_,h,w = img2s.size()
                                inputs = torch.cat((ref_imgs,img1s,img2s),0)
                                if r==0:
                                    ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                                print(seq_names[0])
#                                    if damage_initial_previous_frame_mask:
#                                        try:
#                                            label1s = damage_masks(label1s)
#                                        except:
#                                            label1s = label1s
#                                            print('damage_error') 
                                # Perturb the previous-frame masks (damage factor
                                # 0.0 here, i.e. effectively unchanged shape path).
                                label1s_tocat=torch.Tensor()
                                for i in range(bs):
                                    l = label1s[i]
                                    l = l.unsqueeze(0)
                                    l = mask_damager(l,0.0)
                                    l = torch.from_numpy(l)

                                    l = l.unsqueeze(0).unsqueeze(0)

                                    label1s_tocat = torch.cat((label1s_tocat,l.float()),0)
                                label1s = label1s_tocat
                                if self.use_gpu:
                                    inputs = inputs.cuda()
                                    ref_scribble_labels=ref_scribble_labels.cuda()
                                    label1s = label1s.cuda()

                                tmp_dic, global_map_tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,
                                                            gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,global_map_tmp_dic=global_map_tmp_dic,
                                                           frame_num=frame_nums)
                                pred_label = tmp_dic[seq_names[0]].detach().cpu()
                                pred_label = nn.functional.interpolate(pred_label,size=(h,w),mode = 'bilinear',align_corners=True)

                                pred_label=torch.argmax(pred_label,dim=1)
                                pred_label= pred_label.unsqueeze(0)
                                # Best-effort damage; on failure keep the clean
                                # prediction (deliberate silent fallback).
                                try:
                                    pred_label=damage_masks(pred_label)
                                except:
                                    pred_label=pred_label
                                pred_label=pred_label.squeeze(0)
                                # Ask the robot for next-round scribbles based on
                                # the prediction vs. ground truth.
                                round_scribble[seq_names[0]]=interactor.interact(seq_names[0],pred_label.numpy(),label2s.float().squeeze(0).numpy(),obj_nums)
                                frame_num_dic[seq_names[0]]=frame_nums[0]
                                pred_label=pred_label.unsqueeze(0)
                                # Rescale the prediction back to the sequence's
                                # original resolution (read from frame 00000).
                                img_ww=Image.open(os.path.join(cfg.DATA_ROOT,'JPEGImages/480p/',seq_names[0],'00000.jpg'))
                                img_ww=np.array(img_ww)
                                or_h,or_w = img_ww.shape[:2]
                                pred_label = torch.nn.functional.interpolate(pred_label.float(),(or_h,or_w),mode='nearest')

                                prev_round_label_dic[seq_names[0]]=pred_label.squeeze(0)
#                                torch.cuda.empty_cache() 
                        train_dataset.update_ref_frame_and_label(round_scribble,frame_num_dic,prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*'*100)
                    self.model.train()
                    print('updating ref frame and label')


                    train_dataset.transform=composed_transforms 
                    print('updating ref frame and label finished!')



                else:
                    # Cheap alternative: synthesize next-round inputs from
                    # damaged ground-truth labels instead of model output.
                    if r!=round_-1:
                        round_scribble={}

                        if r ==0:
                            prev_round_label_dic={}
                        #    eval_global_map_tmp_dic={}
                        frame_num_dic= {}
                        train_dataset.transform=tr.ToTensor()
#                        train_dataset.transform=tr.ToTensor()
                        trainloader = DataLoader(train_dataset,batch_size=1,
                            sampler = RandomIdentitySampler(train_dataset.sample_list), 
                            shuffle=False,num_workers=1,pin_memory=True)

                        self.model.eval()
                        with torch.no_grad():
                            for ii, sample in enumerate(trainloader):
                                ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                                img1s = sample['img1'] 
                                img2s = sample['img2']
                                ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                                label1s = sample['label1']
                                label2s = sample['label2']
                                seq_names = sample['meta']['seq_name'] 
                                obj_nums = sample['meta']['obj_num']
                                frame_nums = sample['meta']['frame_num']
                                bs,_,h,w = img2s.size()

 #                               inputs=torch.cat((ref_imgs,img1s,img2s),0)
#                                if r==0:
#                                    ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                                print(seq_names[0])
                                label2s_ = mask_damager(label2s,0.1)
                                # inputs = inputs.cuda()
                                # ref_scribble_labels=ref_scribble_labels.cuda()
                                # label1s = label1s.cuda()



                                # tmp_dic, global_map_tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,
                                #                                 gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,global_map_tmp_dic=global_map_tmp_dic,
                                #                                frame_num=frame_nums)    

                                #label2s_show = label2s.squeeze().numpy()
                                #label2s_im = Image.fromarray(label2s_show.astype('uint8')).convert('P')
                                #label2s_im.putpalette(_palette)
                                #label2s_im.save('label.png')
                                #label2s__show = label2s_
                                #label2s_im_ = Image.fromarray(label2s__show.astype('uint8')).convert('P')
                                #label2s_im_.putpalette(_palette)
                                #label2s_im_.save('damage_label.png')
                        #        exit()


                                #print(label2s_)
                                #print(label2s.size())
                                round_scribble[seq_names[0]]=interactor.interact(seq_names[0],np.expand_dims(label2s_,axis=0),label2s.float().squeeze(0).numpy(),obj_nums)
                                label2s__=torch.from_numpy(label2s_)
                                # img_ww=Image.open(os.path.join(cfg.DATA_ROOT,'JPEGImages/480p/',seq_names[0],'00000.jpg'))
                                # img_ww=np.array(img_ww)
                                # or_h,or_w = img_ww.shape[:2]
                                # label2s__=label2s__.unsqueeze(0).unsqueeze(0)        
                                # label2s__ = torch.nn.functional.interpolate(label2s__.float(),(or_h,or_w),mode='nearest')
                                # label2s__=label2s__.squeeze(0)
#                                print(label2s__.size())

                                frame_num_dic[seq_names[0]]=frame_nums[0]
                                prev_round_label_dic[seq_names[0]]=label2s__


                                #torch.cuda.empty_cache()     
                        print('trainset evaluating finished!')
                        print('*'*100)
                        print('updating ref frame and label')

                        train_dataset.update_ref_frame_and_label(round_scribble,frame_num_dic,prev_round_label_dic)
                        self.model.train()
                        train_dataset.transform=composed_transforms 
                        print('updating ref frame and label finished!')
Example #7
0
class Solver(object):
    """Training / evaluation driver for a person re-identification network.

    Tracks running loss and accuracy with AverageMeters, periodically
    evaluates rank-1 / mAP on a query+gallery split, and checkpoints the
    network (keeping track of the best rank-1 model).
    """

    def __init__(self, opt, net):
        self.opt = opt
        self.net = net
        self.loss = AverageMeter('loss')
        self.acc = AverageMeter('acc')

    def fit(self, train_data, test_data, num_query, optimizer, criterion,
            lr_scheduler):
        """Train for ``opt.train.num_epochs`` epochs.

        Args:
            train_data: loader yielding ``(data, pids, _)`` batches.
            test_data: combined query+gallery loader (or None to skip eval).
            num_query: number of leading samples in ``test_data`` that form
                the query set; the remainder is the gallery.
            optimizer: optimizer over the network parameters.
            criterion: loss taking ``(score, feat, label)``.
            lr_scheduler: object with ``update(epoch) -> lr``.
        """
        best_rank1 = -np.inf
        for epoch in range(self.opt.train.num_epochs):
            self.loss.reset()
            self.acc.reset()
            self.net.train()
            # update learning rate
            lr = lr_scheduler.update(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            logging.info('learning rate update to {:.3e}'.format(lr))

            tic = time.time()
            btic = time.time()
            for i, inputs in enumerate(train_data):
                data, pids, _ = inputs
                label = pids.cuda()
                score, feat = self.net(data)
                loss = criterion(score, feat, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.loss.update(loss.item())
                acc = (score.max(1)[1] == label.long()).float().mean().item()
                self.acc.update(acc)

                log_interval = self.opt.misc.log_interval
                if log_interval and not (i + 1) % log_interval:
                    loss_name, loss_value = self.loss.get()
                    metric_name, metric_value = self.acc.get()
                    logging.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\t'
                        '%s=%f' %
                        (epoch, i + 1, train_data.batch_size * log_interval /
                         (time.time() - btic), loss_name, loss_value,
                         metric_name, metric_value))
                    btic = time.time()

            loss_name, loss_value = self.loss.get()
            metric_name, metric_value = self.acc.get()
            throughput = int(train_data.batch_size * len(train_data) /
                             (time.time() - tic))

            logging.info(
                '[Epoch %d] training: %s=%f\t%s=%f' %
                (epoch, loss_name, loss_value, metric_name, metric_value))
            logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                         (epoch, throughput, time.time() - tic))

            is_best = False
            if test_data is not None and self.opt.misc.eval_step and not (
                    epoch + 1) % self.opt.misc.eval_step:
                rank1 = self.test_func(test_data, num_query)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1
            # NOTE(review): assumes self.net is DataParallel-wrapped
            # (accesses .module) — confirm against the construction site.
            state_dict = self.net.module.state_dict()
            if not (epoch + 1) % self.opt.misc.save_step:
                save_checkpoint(
                    {
                        'state_dict': state_dict,
                        'epoch': epoch + 1,
                    },
                    is_best=is_best,
                    save_dir=self.opt.misc.save_dir,
                    filename=self.opt.network.name + str(epoch + 1) +
                    '.pth.tar')

    def test_func(self, test_data, num_query):
        """Evaluate on ``test_data`` and return the rank-1 CMC score.

        The first ``num_query`` extracted features form the query set, the
        rest the gallery; Euclidean distances are computed and fed to
        :meth:`eval_func`.
        """
        self.net.eval()
        feat, person, camera = list(), list(), list()
        for inputs in test_data:
            data, pids, camids = inputs
            with torch.no_grad():
                outputs = self.net(data).cpu()
            feat.append(outputs)
            person.extend(pids.numpy())
            camera.extend(camids.numpy())
        feat = torch.cat(feat, 0)
        qf = feat[:num_query]
        q_pids = np.asarray(person[:num_query])
        q_camids = np.asarray(camera[:num_query])
        gf = feat[num_query:]
        g_pids = np.asarray(person[num_query:])
        g_camids = np.asarray(camera[num_query:])

        logging.info(
            "Extracted features for query set, obtained {}-by-{} matrix".
            format(qf.shape[0], qf.shape[1]))
        logging.info(
            "Extracted features for gallery set, obtained {}-by-{} matrix".
            format(gf.shape[0], gf.shape[1]))

        logging.info("Computing distance matrix")

        # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 q.g
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # The positional (beta, alpha, ...) form of addmm_ was removed in
        # PyTorch >= 1.5; use keyword arguments instead.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.numpy()

        logging.info("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in [1, 5, 10]:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")
        return cmc[0]

    @staticmethod
    def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
        """Evaluation with market1501 metric.

        Key: for each query identity, its gallery images from the same
        camera view are discarded.

        Args:
            distmat: (num_q, num_g) distance matrix, smaller = more similar.
            q_pids, g_pids: person ids of query / gallery samples.
            q_camids, g_camids: camera ids of query / gallery samples.
            max_rank: CMC curve length (clipped to the gallery size).

        Returns:
            Tuple ``(all_cmc, mAP)`` where ``all_cmc`` is a float32 array of
            length ``max_rank``.

        Raises:
            AssertionError: if no query identity appears in the gallery.
        """
        num_q, num_g = distmat.shape
        if num_g < max_rank:
            max_rank = num_g
            print("Note: number of gallery samples is quite small, got {}".
                  format(num_g))
        indices = np.argsort(distmat, axis=1)
        matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

        # compute cmc curve for each query
        all_cmc = []
        all_AP = []
        num_valid_q = 0.  # number of valid query
        for q_idx in range(num_q):
            # get query pid and camid
            q_pid = q_pids[q_idx]
            q_camid = q_camids[q_idx]

            # remove gallery samples that have the same pid and camid with query
            order = indices[q_idx]
            remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
            keep = np.invert(remove)

            # compute cmc curve
            # binary vector, positions with value 1 are correct matches
            orig_cmc = matches[q_idx][keep]
            if not np.any(orig_cmc):
                # this condition is true when query identity does not appear in gallery
                continue

            cmc = orig_cmc.cumsum()
            cmc[cmc > 1] = 1

            all_cmc.append(cmc[:max_rank])
            num_valid_q += 1.

            # compute average precision
            # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
            num_rel = orig_cmc.sum()
            tmp_cmc = orig_cmc.cumsum()
            tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
            tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
            AP = tmp_cmc.sum() / num_rel
            all_AP.append(AP)

        assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

        all_cmc = np.asarray(all_cmc).astype(np.float32)
        all_cmc = all_cmc.sum(0) / num_valid_q
        mAP = np.mean(all_AP)

        return all_cmc, mAP
Example #8
0
    def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
        """Compute the query-gallery distance matrix and score it.

        Query features are extracted in one pass; gallery features are
        extracted on-line and popped tracklet by tracklet (sizes given by
        ``galleryinfo.tranum``) to bound memory use. Each (query, gallery)
        cell is the mean of the 20%-smallest frame-pair distances.

        :return: the result of ``evaluate_seq`` on the final distance matrix.
        """
        self.cnn_model.eval()

        querypid = queryinfo.pid
        querycamid = queryinfo.camid
        querytranum = queryinfo.tranum

        gallerypid = galleryinfo.pid
        gallerycamid = galleryinfo.camid
        gallerytranum = galleryinfo.tranum

        query_features = self.extract_feature(self.cnn_model, query_loader)

        querylen = len(querypid)
        gallerylen = len(gallerypid)

        # online gallery extraction
        single_distmat = np.zeros((querylen, gallerylen))
        gallery_resize = 0
        gallery_popindex = 0
        gallery_popsize = gallerytranum[gallery_popindex]
        gallery_resfeatures = 0
        gallery_empty = True
        preimgs = 0
        preflows = 0

        # time
        gallery_time = AverageMeter()
        end = time.time()

        for i, (imgs, flows, _, _) in enumerate(gallery_loader):
            # BUG FIX: Variable(..., volatile=True) was removed from PyTorch;
            # torch.no_grad() is the supported way to disable autograd here.
            seqnum = imgs.size(0)

            if i == 0:
                preimgs = imgs
                preflows = flows

            with torch.no_grad():
                if gallery_empty:
                    out_feat = self.cnn_model(imgs, flows, self.mode)
                    gallery_resfeatures = out_feat
                    gallery_empty = False
                elif imgs.size(0) < gallery_loader.batch_size:
                    # Pad a short trailing batch with frames from the first
                    # batch so the model always sees a full batch, then keep
                    # only the genuine rows.
                    flaw_batchsize = imgs.size(0)
                    cat_batchsize = gallery_loader.batch_size - flaw_batchsize
                    imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                    flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
                    out_feat = self.cnn_model(imgs, flows, self.mode)
                    out_feat = out_feat[0:flaw_batchsize]
                    gallery_resfeatures = torch.cat(
                        (gallery_resfeatures, out_feat), 0)
                else:
                    out_feat = self.cnn_model(imgs, flows, self.mode)
                    gallery_resfeatures = torch.cat(
                        (gallery_resfeatures, out_feat), 0)

            gallery_resize = gallery_resize + seqnum

            # Pop every tracklet that is now fully buffered and score it
            # against all queries.
            while gallery_popsize <= gallery_resize:
                if (gallery_popindex + 1) % 50 == 0:
                    print('gallery--{:04d}'.format(gallery_popindex))
                if gallery_popsize == 1:
                    gallery_popfeatures = gallery_resfeatures
                else:
                    gallery_popfeatures = gallery_resfeatures[
                        0:gallery_popsize, :]
                if gallery_popsize < gallery_resize:
                    gallery_resfeatures = gallery_resfeatures[
                        gallery_popsize:gallery_resize, :]
                else:
                    gallery_resfeatures = 0
                    gallery_empty = True
                gallery_resize = gallery_resize - gallery_popsize
                distmat_qall_g = pairwise_distance_tensor(
                    query_features, gallery_popfeatures)

                q_start = 0
                for qind, qnum in enumerate(querytranum):
                    distmat_qg = distmat_qall_g[q_start:q_start + qnum, :]
                    distmat_qg = distmat_qg.cpu().numpy()
                    percile = np.percentile(distmat_qg, 20)

                    # BUG FIX: fancy indexing never returns None, so the old
                    # `... is not None` test was always true and could take
                    # the mean of an empty array (NaN). Test for actual
                    # entries below the percentile instead.
                    below = distmat_qg < percile
                    if below.any():
                        distmean = np.mean(distmat_qg[below])
                    else:
                        distmean = np.mean(distmat_qg)

                    single_distmat[qind, gallery_popindex] = distmean
                    q_start = q_start + qnum

                gallery_popindex = gallery_popindex + 1

                if gallery_popindex < gallerylen:
                    gallery_popsize = gallerytranum[gallery_popindex]
                gallery_time.update(time.time() - end)
                end = time.time()

        return evaluate_seq(single_distmat, querypid, querycamid, gallerypid,
                            gallerycamid)
Example #9
0
    def train(self, epoch, data_loader, optimizer1, optimizer2):
        """Run one training epoch over ``data_loader``.

        :param epoch: current epoch index (used for logging only)
        :param data_loader: yields raw batches consumed by self._parse_data
        :param optimizer1, optimizer2: both are zeroed and stepped per batch
        """
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()
        precisions1 = AverageMeter()
        precisions2 = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            inputs, targets = self._parse_data(inputs)

            loss, prec_oim, prec_score, prec_finalscore = self._forward(
                inputs, targets)
            # BUG FIX: `loss.data[0]` is the pre-0.4 PyTorch idiom and raises
            # IndexError on modern 0-dim loss tensors; use .item() instead.
            losses.update(loss.item(), targets.size(0))

            precisions.update(prec_oim, targets.size(0))
            precisions1.update(prec_score, targets.size(0))
            precisions2.update(prec_finalscore, targets.size(0))

            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss.backward()
            optimizer1.step()
            optimizer2.step()

            batch_time.update(time.time() - end)
            end = time.time()
            print_freq = 50
            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'prec_oim {:.2%} ({:.2%})\t'
                      'prec_score {:.2%} ({:.2%})\t'
                      'prec_finalscore(total) {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(data_loader), losses.val,
                          losses.avg, precisions.val, precisions.avg,
                          precisions1.val, precisions1.avg, precisions2.val,
                          precisions2.avg))
Example #10
0
def validate(val_loader, model, criterion):
    """Run one evaluation pass over ``val_loader``.

    :return: (average loss, average top-1 precision, average top-5 precision)
    """
    batch_time, losses = AverageMeter(), AverageMeter()
    top1, top5 = AverageMeter(), AverageMeter()

    # evaluation mode: freezes dropout / batch-norm statistics
    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, (input, target) in enumerate(val_loader):
            input, target = input.to(args.device), target.to(args.device)

            output = model(input)
            loss = criterion(output, target)

            # record loss and accuracy, weighted by the batch size
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            batch = input.size(0)
            losses.update(loss.item(), batch)
            top1.update(float(prec1), batch)
            top5.update(float(prec5), batch)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if step % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          step, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Example #11
0
def TrainOneEpoch(train_loader, model, optimizer, criterion, epoch_num,
                  vis_tool, record_value):
    """Train ``model`` for one epoch with per-batch class weighting.

    :param vis_tool: visdom-style plotter used every opt.train_plotfreq batches
    :param record_value: best-so-far value, passed through unchanged
    :return: (average loss, average Dice, average Recall, record_value)
    """
    losses = AverageMeter()
    train_eval = ConfusionMeter(num_class=opt.out_dim)

    # running averages of the evaluation metrics
    train_dice = AverageMeter()
    train_recall = AverageMeter()

    best_value = record_value

    model.train()

    for batch_ids, (data, target) in enumerate(train_loader):
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()

        output = model(data)

        # calculate the weight of the batch
        weight = GetWeight(opt, target, slr=0, is_t=0)
        loss = criterion(output, target, weight=weight)

        loss.backward()
        optimizer.step()

        # update the loss value
        losses.update(loss.item())

        # calculate the metrics for evaluation
        _, pred = torch.max(output, 1)
        train_eval.update(pred, target)

        dice_value = train_eval.get_scores('Dice')
        recall_value = train_eval.get_scores('Recall')

        train_dice.update(dice_value)
        train_recall.update(recall_value)

        # for visualization
        if batch_ids % opt.train_plotfreq == 0:
            vis_tool.plot('Train_Loss', loss.item())
            vis_tool.plot('Train_Dice', dice_value)
            vis_tool.plot('Train_Recall', recall_value)

        print('Train:Batch_Num:{}  Loss:{:.3f}  Dice:{:.3f}  Recall:{:.3f}'.
              format(batch_ids, loss.item(), dice_value, recall_value))

    # BUG FIX: `avg_loss` was only bound inside the loop, so an empty loader
    # raised NameError at return; use the meter's average directly.
    return losses.avg, train_dice.avg, train_recall.avg, best_value
Example #12
0
def TrainOneEpoch(train_loader, model, optimizer, criterion, epoch_num,
                  vis_tool, prefix):
    """Train ``model`` for one epoch and plot MIP previews of pred vs label.

    :param vis_tool: visdom-style plotter used every opt.train_plotfreq batches
    :return: (average loss, average Dice, average Recall)
    """
    losses = AverageMeter()
    train_eval = ConfusionMeter(num_class=opt.out_dim)

    # running averages of the evaluation metrics
    train_dice = AverageMeter()
    train_recall = AverageMeter()

    model.train()

    for batch_ids, (data, target) in enumerate(train_loader):
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()

        output = model(data)
        # calculate the weight of the batch
        weight = GetWeight(opt, target, slr=0.00001, is_t=0)
        loss = criterion(output, target, weight=weight)

        loss.backward()
        optimizer.step()

        # update the loss value
        losses.update(loss.item())

        # calculate the metrics for evaluation
        _, pred = torch.max(output, 1)
        train_eval.update(pred, target)

        dice_value = train_eval.get_scores('Dice')
        recall_value = train_eval.get_scores('Recall')

        train_dice.update(dice_value)
        train_recall.update(recall_value)

        # begin to show the results
        if batch_ids % opt.train_plotfreq == 0:
            vis_tool.plot('Train_Loss', loss.item())
            vis_tool.plot('Train_Dice', dice_value)
            vis_tool.plot('Train_Recall', recall_value)

            # un-normalize the first volume and window it to [150, 350] HU
            # (assumes all_std1/all_mean1 are the dataset statistics —
            # TODO confirm)
            image1 = data.cpu().numpy()[0, 0, ...]
            image1 = image1 * all_std1 + all_mean1
            image1 = np.clip(image1, 150, 350)
            image1 = (image1 - 150) / 200
            image1_mip = np.hstack(
                [np.max(image1, 0),
                 np.max(image1, 1),
                 np.max(image1, 2)])

            # maximum-intensity projections of the prediction
            pred1 = pred.cpu().numpy()
            pred1 = pred1[0, ...]
            mip1 = np.hstack(
                [np.max(pred1, 0),
                 np.max(pred1, 1),
                 np.max(pred1, 2)])

            # maximum-intensity projections of the label
            target1 = target.cpu().numpy()
            target1 = target1[0, ...]
            mip2 = np.hstack(
                [np.max(target1, 0),
                 np.max(target1, 1),
                 np.max(target1, 2)])
            mip3 = np.vstack([image1_mip, mip1, mip2])
            vis_tool.img('pred_label', np.uint8(255 * mip3))

        print('Train:Batch_Num:{}  Loss:{:.3f}  Dice:{:.3f}  Recall:{:.3f}'.
              format(batch_ids, loss.item(), dice_value, recall_value))

    # BUG FIX: `avg_loss` was only bound inside the loop, so an empty loader
    # raised NameError at return; use the meter's average directly.
    return losses.avg, train_dice.avg, train_recall.avg
Example #13
0
    def __eval(self, topk):
        """Evaluate the model on ``self.dataloader``.

        :param topk: tuple of ranks passed to accuracy/getclassAccuracy
        :return: (top1.avg, top2.avg, top5.avg, losses.avg, ClassTPDic) where
            ClassTPDic maps 'Top1'/'Top2'/'Top5' to per-class TP counts.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top2 = AverageMeter()
        top5 = AverageMeter()
        # BUG FIX: the per-class TP accumulators were torch.uint8, which
        # silently wraps around once a class collects more than 255 true
        # positives; accumulate in int64 instead.
        ClassTPs_Top1 = torch.zeros(1, len(self.classes),
                                    dtype=torch.long).cuda()
        ClassTPs_Top2 = torch.zeros(1, len(self.classes),
                                    dtype=torch.long).cuda()
        ClassTPs_Top5 = torch.zeros(1, len(self.classes),
                                    dtype=torch.long).cuda()
        y_preds = []
        y_trues = []

        # Start data time
        data_time_start = time.time()
        with torch.no_grad():
            for i, (images, labels, orig_attrs) in enumerate(self.dataloader):
                start_time = time.time()
                if self.use_cuda:
                    images, labels = images.cuda(), labels.cuda()
                if self.ten_crops:
                    # fold the crop dimension into the batch dimension
                    bs, ncrops, c, h, w = images.size()
                    images = images.view(-1, c, h, w)

                if self.with_attribute:
                    # binarize attributes at threshold self.xi
                    orig_attrs = orig_attrs.cuda()
                    attrs = orig_attrs.detach().clone()
                    attrs[attrs > self.xi] = 1.
                    attrs[attrs <= self.xi] = 0.
                    outputs, _ = self.model(images, orig_attrs)
                else:
                    # NOTE(review): orig_attrs is not moved to CUDA in this
                    # branch — confirm the model ignores it here.
                    outputs = self.model(images, orig_attrs)

                if self.ten_crops:
                    # average the logits over the ten crops
                    outputs = outputs.view(bs, ncrops, -1).mean(1)

                loss = self.criterion(outputs, labels)

                y_pred = outputs.argmax(dim=1)
                y_trues = np.append(y_trues, labels.cpu().numpy(), axis=0)
                y_preds = np.append(y_preds, y_pred.cpu().numpy(), axis=0)

                # Compute class accuracy
                ClassTPs = getclassAccuracy(outputs, labels, len(self.classes),
                                            topk)
                ClassTPs_Top1 += ClassTPs[0]
                ClassTPs_Top2 += ClassTPs[1]
                ClassTPs_Top5 += ClassTPs[2]

                # Measure Top1, Top2 and Top5 accuracy
                prec1, prec2, prec5 = accuracy(outputs.data, labels.data, topk)

                losses.update(loss.item(), labels.size(0))
                top1.update(prec1.item(), labels.size(0))
                top2.update(prec2.item(), labels.size(0))
                top5.update(prec5.item(), labels.size(0))

                batch_time.update(time.time() - start_time)
                if (i + 1) % 10 == 0:
                    print('Testing batch: [{}/{}]\t'
                          'Loss {loss.val:.3f} (avg: {loss.avg:.3f})\t'
                          'Prec@1 {top1.val:.3f} (avg: {top1.avg:.3f})\t'
                          'Prec@2 {top2.val:.3f} (avg: {top2.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} (avg: {top5.avg:.3f})'.format(
                              i,
                              len(self.dataloader),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1,
                              top2=top2,
                              top5=top5))

            ClassTPDic = {
                'Top1': ClassTPs_Top1.cpu().numpy(),
                'Top2': ClassTPs_Top2.cpu().numpy(),
                'Top5': ClassTPs_Top5.cpu().numpy()
            }

            # BUG FIX: the original passed the *builtin* `set` to format(),
            # printing "<class 'set'>" in the message; print a plain message.
            print('Elapsed time for evaluation {time:.3f} seconds'.format(
                time=time.time() - data_time_start))
            print("")
            print(
                metrics.precision_score(y_true=y_trues,
                                        y_pred=y_preds,
                                        average='micro'))
            return top1.avg, top2.avg, top5.avg, losses.avg, ClassTPDic
Example #14
0
def train_using_metriclearning_with_inception3(model,
                                               optimizer,
                                               criterion,
                                               epoch,
                                               train_root,
                                               train_pictures,
                                               prefix,
                                               distance_dict=None,
                                               class_to_nearest_class=None):
    """Train one epoch of triplet metric learning.

    Triplets (anchor, positive, negative) are accumulated picture by picture
    and optimized every 128 samples (and once more at the end of the list).
    Half the time (random coin flip) the positive/negative are "hard"
    samples drawn from distance_dict / class_to_nearest_class when provided.

    :return: True when the loss dropped below 1e-5 at least once,
        signalling that the triplet margin can be enlarged.
    """
    start = time.time()
    model.train()

    losses = AverageMeter()
    is_add_margin = False
    feature_util = FeatureUtil(G.WIDTH, G.HEIGHT)
    if train_pictures is None:
        train_pictures = os.listdir(train_root)
    # BUG FIX: with fewer than 6 pictures the old `int(len(...) / 6)` gave 0
    # and `% log_freq` below raised ZeroDivisionError; clamp to at least 1.
    log_freq = max(1, len(train_pictures) // 6)

    anchor_ls, positive_ls, negative_ls = [], [], []
    for i, picture_path in enumerate(train_pictures):
        # class id is encoded after the last '_' (file extension stripped)
        cls_idx = picture_path.split('_')[-1][:-4]

        anchor_input = feature_util.get_proper_input(os.path.join(
            train_root, picture_path),
                                                     ls_form=True)
        anchor_ls.append(anchor_input)

        hard_sample = random.randint(
            1, 2) % 2 == 0  # decide if use random sample or hard sample
        if hard_sample and distance_dict is not None and class_to_nearest_class is not None:
            # hard positive: a near neighbour from the same class
            random_int = random.randint(0, 39)
            random_int = min(random_int, len(distance_dict[cls_idx]) - 1)
            p_input_pic = distance_dict[cls_idx][random_int][0] if len(
                distance_dict[cls_idx][random_int]) > 0 else picture_path
        else:
            # random positive: any other picture of the same class
            p_pictures = [
                x for x in train_pictures
                if x.split('_')[-1][:-4] == cls_idx and x != picture_path
            ]
            random.shuffle(p_pictures)
            p_input_pic = p_pictures[0] if len(
                p_pictures) > 0 else picture_path
        p_input = feature_util.get_proper_input(os.path.join(
            train_root, p_input_pic),
                                                ls_form=True)
        positive_ls.append(p_input)

        if hard_sample and distance_dict is not None and class_to_nearest_class is not None:
            # hard negative: pictures from the nearest *other* class
            n_pictures = [
                x for x in train_pictures
                if x.split('_')[-1][:-4] == class_to_nearest_class[cls_idx]
            ]
            if len(n_pictures) == 0:
                n_pictures = [
                    x for x in train_pictures
                    if x.split('_')[-1][:-4] != cls_idx
                ]
        else:
            n_pictures = [
                x for x in train_pictures if x.split('_')[-1][:-4] != cls_idx
            ]
        random.shuffle(n_pictures)
        n_input_pic = n_pictures[0]
        n_input = feature_util.get_proper_input(os.path.join(
            train_root, n_input_pic),
                                                ls_form=True)
        negative_ls.append(n_input)

        # flush the accumulated "batch" of triplets
        if ((i + 1) == len(train_pictures)) or ((i + 1) % 128 == 0):

            anchor_ls = torch.Tensor(anchor_ls)
            positive_ls = torch.Tensor(positive_ls)
            negative_ls = torch.Tensor(negative_ls)

            anchor_ls = anchor_ls.cuda()
            positive_ls = positive_ls.cuda()
            negative_ls = negative_ls.cuda()

            anchor_ls = model(anchor_ls)
            positive_ls = model(positive_ls)
            negative_ls = model(negative_ls)

            loss = criterion(anchor_ls, positive_ls, negative_ls)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())
            if losses.val < 1e-5:
                is_add_margin = True
            # reset for next training "batch"
            anchor_ls, positive_ls, negative_ls = [], [], []

        if (i + 1) % log_freq == 0:
            # NOTE(review): this AverageMeter exposes `.mean` rather than
            # `.avg` as elsewhere in this file — confirm against its class.
            print('Epoch: {}[{}/{}]\t'
                  'Loss {:.6f} ({:.6f})\t'.format(epoch, i + 1,
                                                  len(train_pictures),
                                                  losses.val, losses.mean))

    time_token = time.time() - start

    param_group = optimizer.param_groups
    print('Epoch: [{}]\tEpoch Time {:.1f} s\tLoss {:.6f}\t'
          'Lr {:.2e}'.format(epoch, time_token, losses.mean,
                             param_group[0]['lr']))
    return is_add_margin
Example #15
0
def validate(val_loader, model, criterion, device):
    """Run one evaluation pass, reporting Prec@1..Prec@4.

    Inputs with an extra leading dimension (e.g. clip batches) are flattened
    into the batch dimension before the forward pass.

    :return: average top-1 precision
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    topk = [AverageMeter() for i in range(4)]

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            if input.dim() > 4:
                # fold the extra dimension into the batch
                input = input.reshape(input.shape[0] * input.shape[1],
                                      input.shape[2], input.shape[3],
                                      input.shape[4])
                target = target.reshape(target.shape[0] * target.shape[1],
                                        target.shape[2])

            input = input.to(device)
            target = target.to(device)
            target = target.float()
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            output = output.cpu()
            target = target.cpu()
            prec = accuracy(output, target, topk=4)
            # BUG FIX: the inner loop previously reused `i`, clobbering the
            # batch index so the print below always saw i == 3.
            for k in range(4):
                topk[k].update(prec[k], input.size(0))
            losses.update(loss.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\n'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@2 {top2.val:.3f} ({top2.avg:.3f})\t'
                      'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
                      'Prec@4 {top4.val:.3f} ({top4.avg:.3f})\t'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=topk[0],
                          top2=topk[1],
                          top3=topk[2],
                          top4=topk[3]))

        print(
            ' * Prec@1 {top1.avg:.3f} Prec@2 {top2.avg:.3f} Prec@3 {top3.avg:.3f} Prec@4 {top4.avg:.3f}'
            .format(top1=topk[0], top2=topk[1], top3=topk[2], top4=topk[3]))

    return topk[0].avg
Example #16
0
class ReidSystem():
    def __init__(self, cfg, logger, writer):
        """Wire up dataloaders, model, losses, optimizer and scheduler."""
        self.cfg, self.logger, self.writer = cfg, logger, writer

        # data
        (self.tng_dataloader, self.val_dataloader,
         self.num_classes, self.num_query) = get_dataloader(cfg)

        # network
        self.model = build_model(cfg, self.num_classes)
        self.base_type = self.model.base_type

        # loss functions
        if cfg.SOLVER.LABEL_SMOOTH:
            self.ce_loss = CrossEntropyLabelSmooth(self.num_classes)
        else:
            self.ce_loss = nn.CrossEntropyLoss()
        self.triplet = TripletLoss(cfg.SOLVER.MARGIN)
        self.aligned_triplet = TripletLossAlignedReID(margin=cfg.SOLVER.MARGIN)
        self.of_penalty = OFPenalty(beta=1e-6,
                                    penalty_position=['intermediate'])

        # optimizer and scheduler
        self.opt = make_optimizer(self.cfg, self.model)
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

        self._construct()

    def _construct(self):
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        self.use_dp = False
        self.use_ddp = False

    def loss_fns(self, outputs, labels):
        if self.cfg.MODEL.FINE_TUNE:
            triplet_loss = self.triplet(outputs, labels)[0]
            return {'global_triplet_loss': triplet_loss}
        elif self.cfg.SOLVER.TRIPLET_ONLY:
            triplet_loss = self.triplet(outputs[1], labels)[0]
            return {'global_triplet_loss': triplet_loss}
        else:
            ce_loss = self.ce_loss(outputs[0], labels)
            triplet_loss = self.triplet(outputs[1], labels)[0]
            return {'ce_loss': ce_loss, 'global_triplet_loss': triplet_loss}

    def aligned_loss_fns(self, outputs, labels):
        """

        :param outputs: [cls_score, global_feature, local_feature]
        :param labels: person IDs
        :return:
        """
        ce_loss = self.ce_loss(outputs[0], labels)
        global_triplet_loss, local_triplet_loss = self.aligned_triplet(
            outputs[1], labels, outputs[2])
        #return {'ce_loss': ce_loss, 'globaltriplet': triplet_loss}
        return {
            'ce_loss': ce_loss,
            'global_triplet_loss': global_triplet_loss,
            'local_triplet_loss': local_triplet_loss
        }

    def mgn_loss_fns(self, outputs, labels):
        triplet_loss = [
            self.triplet(output, labels)[0] for output in outputs[1]
        ]
        triplet_loss = sum(triplet_loss) / len(triplet_loss)

        ce_loss = [self.ce_loss(output, labels) for output in outputs[2]]
        ce_loss = sum(ce_loss) / len(ce_loss)

        return {'ce_loss': ce_loss, 'global_triplet_loss': triplet_loss}

    def on_train_begin(self):
        """One-time setup before the first epoch: meters, save dir,
        optional checkpoint restore, GPU placement."""
        self.start_epoch = 0
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        self.running_CE_loss = AverageMeter()
        self.running_GT_loss = AverageMeter()
        self.running_LT_loss = AverageMeter()
        self.running_OF_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR,
                                    self.cfg.DATASETS.TEST_NAMES,
                                    self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        # Load checkpoint if one is configured.
        # BUG FIX: `is not ''` compares identity, not equality (and emits a
        # SyntaxWarning on Python >= 3.8); use `!=` for string comparison.
        cfg = self.cfg
        if cfg.MODEL.CHECKPOINT != '':
            self.load_checkpoint(cfg.MODEL.CHECKPOINT,
                                 with_optimizer=not cfg.MODEL.FINE_TUNE)

        # assumes CUDA_VISIBLE_DEVICES is always set — TODO confirm launcher
        self.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        self.use_dp = (len(self.gpus) > 0) and (self.cfg.MODEL.DIST_BACKEND
                                                == 'dp')

        if self.use_dp:
            self.model = nn.DataParallel(self.model)

        self.model = self.model.cuda()

        self.model.train()

    def on_epoch_begin(self):
        """Reset per-epoch counters/meters and build a fresh prefetcher."""
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        for meter in (self.running_loss, self.running_CE_loss,
                      self.running_GT_loss, self.running_LT_loss,
                      self.running_OF_loss):
            meter.reset()

        self.tng_prefetcher = data_prefetcher(self.tng_dataloader, self.cfg)

    def training_step(self, batch):
        """Run a single optimisation step on one batch.

        Picks the loss set matching ``self.base_type``, prints a progress
        line, logs to TensorBoard every ``self.log_interval`` steps, updates
        the running meters, and steps ``self.opt``.
        """
        inputs, labels, _ = batch
        outputs = self.model(inputs, labels)

        # Dispatch on the backbone family to the matching loss function.
        if self.base_type in [
                BASE_ALIGNED_RESNET50, BASE_ALIGNED_RESNET101,
                BASE_ALIGNED_RESNEXT101, BASE_ALIGNED_RESNEXT50,
                BASE_ALIGNED_SE_RESNET101, BASE_ALIGNED_DENSENET169,
                BASE_ALIGNED_MPNCOV_RESNET50, BASE_ALIGNED_MPNCOV_RESNET101,
                BASE_ALIGNED_MPNCOV_RESNEXT101
        ]:
            loss_dict = self.aligned_loss_fns(outputs, labels)
        elif self.base_type in [
                BASE_ALIGNED_RESNET50_ABD, BASE_ALIGNED_RESNET101_ABD,
                BASE_ALIGNED_RESNEXT101_ABD
        ]:
            loss_dict = self.aligned_loss_fns(outputs[:3], labels)
            if self.current_epoch >= self.cfg.MODEL.OF_START_EPOCH:  # add the OF penalty from the configured epoch on
                loss_dict['of_loss'] = self.of_penalty(outputs[3])
        elif self.base_type in [BASE_RESNET101_ABD, BASE_RESNEXT101_ABD]:
            loss_dict = self.loss_fns(outputs, labels)
            if self.current_epoch >= self.cfg.MODEL.OF_START_EPOCH:  # add the OF penalty from the configured epoch on
                loss_dict['of_loss'] = self.of_penalty(outputs[2])
        elif self.base_type in [MGN_RESNET50, MGN_RESNET101, MGN_RESNEXT101]:
            loss_dict = self.mgn_loss_fns(outputs, labels)
        else:
            loss_dict = self.loss_fns(outputs, labels)

        # Sum the individual losses while building the progress line.
        total_loss = 0
        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_dataloader)} '
        for loss_name, loss_value in loss_dict.items():
            total_loss += loss_value
            print_str += (loss_name + f': {loss_value.item():.3f} ')
        loss_dict['total_loss'] = total_loss.item()
        print_str += f'Total loss: {total_loss.item():.3f} '
        print(print_str, end=' ')

        # TensorBoard logging (only every log_interval steps).
        if self.writer and (self.global_step + 1) % self.log_interval == 0:
            if 'ce_loss' in loss_dict.keys():
                self.writer.add_scalar('cross_entropy_loss',
                                       loss_dict['ce_loss'], self.global_step)
            self.writer.add_scalar('global_triplet_loss',
                                   loss_dict['global_triplet_loss'],
                                   self.global_step)
            if 'local_triplet_loss' in loss_dict.keys():
                self.writer.add_scalar('local_triplet_loss',
                                       loss_dict['local_triplet_loss'],
                                       self.global_step)
            self.writer.add_scalar('total_loss', loss_dict['total_loss'],
                                   self.global_step)

        # Feed the per-epoch running meters (optional losses guarded).
        self.running_loss.update(total_loss.item())
        if 'ce_loss' in loss_dict.keys():
            self.running_CE_loss.update(loss_dict['ce_loss'])
        self.running_GT_loss.update(loss_dict['global_triplet_loss'])
        if 'local_triplet_loss' in loss_dict.keys():
            self.running_LT_loss.update(loss_dict['local_triplet_loss'])
        if 'of_loss' in loss_dict.keys():
            self.running_OF_loss.update(loss_dict['of_loss'])

        # Optimisation step.
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        print('')
        self.logger.info(
            f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} CE loss: {self.running_CE_loss.avg:.3f} '
            f'GT loss: {self.running_GT_loss.avg:.3f} LT loss: {self.running_LT_loss.avg:.3f} OF loss: {self.running_OF_loss.avg:.3f} '
            f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s'
        )
        # update learning rate
        self.lr_sched.step()

    def test(self):
        """Evaluate the current model on the validation set.

        Extracts features for all query+gallery images, scores query vs
        gallery with negative inner product, logs CMC/mAP, restores train
        mode, and returns {'rank1': ..., 'mAP': ...}.
        """
        # convert to eval mode
        self.model.eval()

        feats, pids, camids = [], [], []
        val_prefetcher = data_prefetcher(self.val_dataloader, self.cfg)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = self.model(img)
            # some backbones return a (feature, ...) tuple
            if isinstance(feat, tuple):
                feats.append(feat[0])
            else:
                feats.append(feat)

            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))

            batch = val_prefetcher.next()

        ####
        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            # L2-normalize so the inner product below behaves like cosine
            feats = F.normalize(feats, p=2, dim=1)

        # query: the first num_query rows
        qf = feats[:self.num_query]

        q_pids = np.asarray(pids[:self.num_query])
        q_camids = np.asarray(camids[:self.num_query])
        # gallery: the remaining rows
        gf = feats[self.num_query:]

        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        # TODO: add re-ranked evaluation results
        # m, n = qf.shape[0], gf.shape[0]
        # negative inner product as the distance (smaller = more similar)
        distmat = -torch.mm(qf, gf.t()).cpu().numpy()

        # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        #           torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # distmat.addmm_(1, -2, qf, gf.t())
        # distmat = distmat.numpy()
        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        self.writer.add_scalar('rank1', cmc[0], self.global_step)
        self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def train(self):
        """Run the full training loop.

        For each epoch: drain the training prefetcher through
        ``training_step``, run epoch-end bookkeeping, and every
        ``eval_period`` epochs evaluate and save a checkpoint.
        """
        self.on_train_begin()
        for epoch in range(self.start_epoch, self.max_epochs):
            self.on_epoch_begin()
            # A leading None in the batch tuple signals the prefetcher is
            # exhausted for this epoch.
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            if (epoch + 1) % self.eval_period == 0:
                metric_dict = self.test()
                # Track the best mAP seen so far (bookkeeping only).
                if metric_dict['mAP'] > self.best_mAP:
                    self.best_mAP = metric_dict['mAP']
                # The original computed is_best from mAP but then
                # unconditionally overwrote it with True ("always save the
                # last checkpoint as the best"); keep that behavior but drop
                # the dead branch.
                self.save_checkpoint(is_best=True)

            torch.cuda.empty_cache()

    def save_checkpoint(self, is_best):
        """Write model weights and optimizer/scheduler state to disk.

        Saves ``model_epoch{E}.pth`` and ``optimizer_epoch{E}.pth`` under
        ``self.model_save_dir``; when ``is_best`` is true, also copies the
        weights to ``model_best.pth``.
        """
        # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
        net = self.model.module if self.use_dp else self.model
        state_dict = net.state_dict()

        # TODO: add optimizer state dict and lr scheduler
        filepath = os.path.join(self.model_save_dir,
                                f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)

        # Persist optimizer/scheduler objects alongside the weights.
        opt_dict = {
            'optimizer': self.opt,
            'lr_scheduler': self.lr_sched,
            'epoch': self.current_epoch,
        }
        optpath = os.path.join(self.model_save_dir,
                               f'optimizer_epoch{self.current_epoch}.pth')
        torch.save(opt_dict, optpath)

        if is_best:
            shutil.copyfile(filepath,
                            os.path.join(self.model_save_dir, 'model_best.pth'))

    def load_checkpoint(self, checkpoint_path, with_optimizer=True):
        """Restore model weights, and optionally optimizer/scheduler state.

        Checkpoint keys absent from the current model are dropped with a log
        message rather than raising. When ``with_optimizer`` is true, the
        companion ``optimizer_epoch*.pth`` file is loaded and the stored
        epoch becomes both ``start_epoch`` and ``current_epoch``.
        """
        self.logger.info('Loading checkpoints from ' + checkpoint_path)
        state_dict = torch.load(checkpoint_path)

        # Keep only parameters the current model actually defines.
        model_keys = self.model.state_dict()
        filtered = {}
        for k, v in state_dict.items():
            if k not in model_keys:
                self.logger.info(f'Remove key {k} from checkpoint.')
            else:
                filtered[k] = v
        state_dict = filtered

        # Load into the underlying module when wrapped in DataParallel.
        target = self.model.module if self.use_dp else self.model
        target.load_state_dict(state_dict)

        ## load optimizer
        if with_optimizer:
            opt_path = checkpoint_path.replace('model_epoch',
                                               'optimizer_epoch')
            self.logger.info('Loading optimizer from ' + opt_path)
            opt_dict = torch.load(opt_path)
            self.opt = opt_dict['optimizer']
            self.lr_sched = opt_dict['lr_scheduler']
            self.start_epoch = opt_dict['epoch']
            self.current_epoch = opt_dict['epoch']
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader``.

    Returns ``(avg_loss, avg_top1, avg_top5)``.

    When ``args.dump_dir`` is set, activations of batch index 5 are dumped
    via the DM context manager and evaluation stops early at that batch.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    if args.dump_dir is not None:
        # NOTE(review): QM/DM look like global quantization-measure and dump
        # managers; QM is disabled before dumping — confirm semantics.
        QM().disable()
        DM(args.dump_dir)

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.to(args.device)
            target = target.to(args.device)
            # Dump only the 6th batch, then stop evaluating entirely
            # (metrics for that batch are never recorded).
            if args.dump_dir is not None and i == 5:
                with DM(args.dump_dir):
                    DM().set_tag('batch%d'%i)
                    # compute output
                    output = model(input)
                    break
            else:
                output = model(input)

            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(float(prec1), input.size(0))
            top5.update(float(prec5), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Example #18
0
    def train(self, epoch, data_loader, optimizer1, optimizer2):
        """Run one training epoch, driving two optimizers from a single
        backward pass, and log step/avg metrics to tensorboard.
        """
        self.model.train()

        timer_batch = AverageMeter()
        timer_data = AverageMeter()
        loss_meter = AverageMeter()
        oim_prec_meter = AverageMeter()
        score_prec_meter = AverageMeter()

        print_freq = 60
        num_step = len(data_loader)  # 1146
        tick = time.time()
        for i, raw_inputs in enumerate(data_loader):
            timer_data.update(time.time() - tick)

            inputs, targets = self._parse_data(raw_inputs)
            loss, prec_oim, prec_score = self._forward(inputs, targets)

            batch_size = targets.size(0)
            loss_meter.update(loss.item(), batch_size)
            oim_prec_meter.update(prec_oim, batch_size)
            score_prec_meter.update(prec_score, batch_size)

            # Both optimizers share the single backward pass.
            for opt in (optimizer1, optimizer2):
                opt.zero_grad()
            loss.backward()
            for opt in (optimizer1, optimizer2):
                opt.step()

            timer_batch.update(time.time() - tick)
            tick = time.time()

            num_iter = num_step * epoch + i
            self.writer.add_scalar('train/loss_step', loss_meter.val, num_iter)
            self.writer.add_scalar('train/loss_avg', loss_meter.avg, num_iter)
            self.writer.add_scalar('train/prec_pairloss', score_prec_meter.avg, num_iter)
            self.writer.add_scalar('train/prec_oimloss', oim_prec_meter.avg, num_iter)
            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'prec_oim {:.2%} ({:.2%})\t'
                      'prec_score {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(data_loader),
                              loss_meter.val, loss_meter.avg,
                              oim_prec_meter.val, oim_prec_meter.avg,
                              score_prec_meter.val, score_prec_meter.avg))
def val_model(opt, val_loader, model, criterion, vis_tool, name1='1'):
    """Evaluate ``model`` on ``val_loader``, tracking loss, Dice and Recall.

    Returns ``(avg_loss, avg_dice, avg_recall)``.

    Bug fix: the original nested the entire evaluation body under
    ``if opt.use_cuda``, so CPU runs did nothing and then crashed on an
    unbound ``avg_loss`` at the return. Only the device transfer is
    conditional now, and inference runs under ``torch.no_grad()``.
    """
    # begin to test the dataset
    model.eval()

    val_eval = ConfusionMeter(num_class=opt.out_dim)

    # calculate the average values
    val_dice = AverageMeter()
    val_loss = AverageMeter()
    val_recall = AverageMeter()

    for batch_ids, (data, target) in enumerate(val_loader):
        # Device transfer is the only CUDA-specific step.
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()

        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, dim=1)

            val_loss.update(loss.item())
            val_eval.update(pred, target)

            dice_value = val_eval.get_scores('Dice')
            recall_value = val_eval.get_scores('Recall')

            val_recall.update(recall_value)
            val_dice.update(dice_value)

            # periodic plotting of the current batch's metrics
            if batch_ids % opt.val_plotfreq == 0:
                vis_tool.plot('Val_Loss' + name1, loss.item())
                vis_tool.plot('Val_Dice' + name1, dice_value)
                vis_tool.plot('Val_Recall' + name1, recall_value)

            print(
                'Val: Batch_Num:{}  Loss:{:.3f}  Dice:{:.3f}  Recall:{:.3f}'
                .format(batch_ids, loss.item(), dice_value, recall_value))

    return val_loss.avg, val_dice.avg, val_recall.avg
Example #20
0
 def __init__(self, opt, net):
     """Store the experiment options and network, and create running
     loss/accuracy meters."""
     self.opt = opt
     self.net = net
     # Named running averages updated during training.
     self.loss = AverageMeter('loss')
     self.acc = AverageMeter('acc')
Example #21
0
    def forward(self,
                data_loader,
                num_steps=None,
                training=False,
                duplicates=1,
                average_output=False,
                chunk_batch=1,
                rec=False):
        """Run one pass over ``data_loader`` (training or evaluation) and
        return meter averages plus derived ``error1``/``error5``.

        When ``rec`` is True, per-sample output embeddings keyed by target
        are collected and saved to 'output_embed_calib' at the end.
        """
        if rec: output_embed = {}
        meters = {
            name: AverageMeter()
            for name in ['step', 'data', 'loss', 'prec1', 'prec5']
        }
        if training and self.grad_clip > 0:
            meters['grad'] = AverageMeter()

        batch_first = True
        if training and isinstance(self.model,
                                   nn.DataParallel) or chunk_batch > 1:
            batch_first = False
        if average_output:
            assert duplicates > 1 and batch_first, "duplicates must be > 1 for output averaging"

        def meter_results(meters):
            # Collapse meters to their averages and add error rates.
            results = {name: meter.avg for name, meter in meters.items()}
            results['error1'] = 100. - results['prec1']
            results['error5'] = 100. - results['prec5']
            return results

        end = time.time()
        for i, (inputs, target) in (enumerate(data_loader)):
            # Periodically re-estimate the gradient scale from per-duplicate
            # gradient norms.
            if training and duplicates > 1 and self.adapt_grad_norm is not None \
                    and i % self.adapt_grad_norm == 0:
                grad_mean = 0
                num = inputs.size(1)
                for j in range(num):
                    grad_mean += float(
                        self._grad_norm(inputs.select(1, j), target))
                grad_mean /= num
                grad_all = float(
                    self._grad_norm(
                        *_flatten_duplicates(inputs, target, batch_first)))
                self.grad_scale = grad_mean / grad_all
                logging.info('New loss scale: %s', self.grad_scale)

            # measure data loading time
            meters['data'].update(time.time() - end)
            if duplicates > 1:  # multiple versions for each sample (dim 1)
                inputs, target = _flatten_duplicates(
                    inputs,
                    target,
                    batch_first,
                    expand_target=not average_output)

            output, loss, grad = self._step(inputs,
                                            target,
                                            training=training,
                                            average_output=average_output,
                                            chunk_batch=chunk_batch)
            if rec:
                with torch.no_grad():
                    # BUG FIX: the original reused `i` as the inner loop
                    # variable, clobbering the batch index used below for
                    # print frequency, pruner bookkeeping and the num_steps
                    # early-exit check.
                    for sample_idx in range(target.shape[0]):
                        tt = target[sample_idx]
                        emb = output[sample_idx]
                        output_embed[tt.tolist()] = emb
            if self.pruner is not None:
                with torch.no_grad():
                    if training:
                        compression_rate = self.pruner.calc_param_masks(
                            self.model, i % self.print_freq == 0,
                            i + self.epoch * len(data_loader))
                        if i % self.print_freq == 0:
                            logging.info('Total compression ratio is: ' +
                                         str(compression_rate))
                    self.model = self.pruner.prune_layers(self.model)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            meters['loss'].update(float(loss), inputs.size(0))
            meters['prec1'].update(float(prec1), inputs.size(0))
            meters['prec5'].update(float(prec5), inputs.size(0))
            if grad is not None:
                meters['grad'].update(float(grad), inputs.size(0))

            # measure elapsed time
            meters['step'].update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                report = str(
                    '{phase} - Epoch: [{0}][{1}/{2}]\t'
                    'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                    'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                    'Loss {meters[loss].val:.7f} ({meters[loss].avg:.7f})\t'
                    'Prec@1 {meters[prec1].val:.6f} ({meters[prec1].avg:.6f})\t'
                    'Prec@5 {meters[prec5].val:.6f} ({meters[prec5].avg:.6f})\t'
                    .format(self.epoch,
                            i,
                            len(data_loader),
                            phase='TRAINING' if training else 'EVALUATING',
                            meters=meters))
                if 'grad' in meters.keys():
                    report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})'\
                        .format(meters=meters)
                logging.info(report)
            # Stop early when a step budget is given, or after a few batches
            # in threshold-only update mode.
            if num_steps is not None and i >= num_steps or (self.update_only_th
                                                            and training
                                                            and i > 2):
                break
        if self.pruner is not None:
            self.pruner.save_eps(epoch=self.epoch + 1)
            self.pruner.save_masks(epoch=self.epoch + 1)

        if rec: torch.save(output_embed, 'output_embed_calib')
        return meter_results(meters)
Example #22
0
    def train(self, epoch, data_loader):
        """Run one training epoch; logs per-step loss and learning rate to
        tensorboard and prints periodic progress.
        """
        self.model.train()

        batch_timer = AverageMeter()
        data_timer = AverageMeter()
        loss_meter = AverageMeter()

        tick = time.time()
        for step, batch in enumerate(data_loader):
            data_timer.update(time.time() - tick)

            # parse -> forward -> zero -> backward -> update
            self._parse_data(batch)
            self._forward()
            self.optimizer.zero_grad()
            self._backward()
            self.optimizer.step()

            batch_timer.update(time.time() - tick)
            loss_meter.update(self.loss.item())

            # tensorboard
            global_step = epoch * len(data_loader) + step
            self.summary_writer.add_scalar('loss', self.loss.item(),
                                           global_step)
            self.summary_writer.add_scalar(
                'lr', self.optimizer.param_groups[0]['lr'], global_step)

            tick = time.time()

            if (step + 1) % self.opt.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Batch Time {:.3f} ({:.3f})\t'
                      'Data Time {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'.format(epoch, step + 1,
                                                      len(data_loader),
                                                      batch_timer.val,
                                                      batch_timer.mean,
                                                      data_timer.val,
                                                      data_timer.mean,
                                                      loss_meter.val,
                                                      loss_meter.mean))
        param_group = self.optimizer.param_groups
        print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
              'Lr {:.2e}'.format(epoch, batch_timer.sum, loss_meter.mean,
                                 param_group[0]['lr']))
        print()
    def eval(self, epoch):
        """Validation pass over ``self.val_loader``.

        Returns ``(avg accuracy, avg loss)``.
        """
        self.model.eval()
        losses = AverageMeter()
        correct = AverageMeter()
        prec1 = AverageMeter()
        prec2 = AverageMeter()
        prec5 = AverageMeter()
        with torch.no_grad():
            for step, (imgs, labels, orig_attrs) in enumerate(self.val_loader):
                imgs, labels = imgs.cuda(), labels.cuda()
                if self.with_attribute:
                    orig_attrs = orig_attrs.cuda()
                    # Binarize soft attribute scores at threshold self.xi
                    # (order of the two in-place ops matters).
                    attrs = orig_attrs.detach().clone()
                    attrs[attrs > self.xi] = 1.
                    attrs[attrs <= self.xi] = 0.
                    pred_id, pred_attrs = self.model(imgs, orig_attrs)
                    assert pred_attrs.shape[-1] == 134
                    loss = self.criterion[0](pred_id, labels)
                    loss_attrs = self.criterion[1](pred_attrs.float(), attrs.float())
                    # Attribute loss only contributes after warm-up epochs.
                    if epoch > 15:
                        loss += loss_attrs
                else:
                    pred_id = self.model(imgs, orig_attrs)
                    loss = self.criterion(pred_id, labels)
                assert pred_id.shape[-1] == self.num_classes

                batch = labels.size(0)
                losses.update(loss.item(), batch)

                prec = accuracy(pred_id.data, labels.data, topk=(1, 2, 5),
                                is_multilabel=False)
                prec1.update(prec[0].item(), batch)
                prec2.update(prec[1].item(), batch)
                prec5.update(prec[2].item(), batch)

                y_pred = pred_id.argmax(dim=1)
                acc = (y_pred == labels).sum().item() / batch * 100
                correct.update(acc, batch / 100.)

        print('Val: [{}] '
              'Loss {:.2f} ({:.2f})\t'
              'Acc {:.2f} ({:.2f})\t'
              'Prec1 {:.2%} ({:.2%})\t'
              'Prec2 {:.2%} ({:.2%})\t'
              'Prec5 {:.2%} ({:.2%})\t'
              .format(epoch,
                      losses.val, losses.avg,
                      correct.val, correct.avg,
                      prec1.val, prec1.avg,
                      prec2.val, prec2.avg,
                      prec5.val, prec5.avg
                      ))

        return correct.avg, losses.avg
Example #24
0
    def forward(self,
                data_loader,
                num_steps=None,
                training=False,
                duplicates=1):
        """Run one pass over ``data_loader`` (training or evaluation).

        Returns meter averages plus derived ``error1``/``error5``. During
        training, a worker schedule hook and optional per-iteration
        tensorboard logging are invoked with the global iteration index.
        """
        meters = {
            name: AverageMeter()
            for name in ['step', 'data', 'loss', 'prec1', 'prec5']
        }
        if training and self.grad_clip > 0:
            meters['grad'] = AverageMeter()

        def meter_results(meters):
            # Collapse meters to their averages and add error rates.
            results = {name: meter.avg for name, meter in meters.items()}
            results['error1'] = 100. - results['prec1']
            results['error5'] = 100. - results['prec5']
            return results

        end = time.time()
        if training:
            # Reset per-epoch histogram of worker delays.
            self.delay_hist = defaultdict(int)
        for i, (inputs, target) in enumerate(data_loader):
            if training:
                self._schedule_worker(self.epoch * len(data_loader) + i)
            if training and tb.tboard.res_iterations:
                tb.tboard.update_step(self.epoch * len(data_loader) + i)
            # measure data loading time
            meters['data'].update(time.time() - end)
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            if duplicates > 1:  # multiple versions for each sample (dim 1)
                target = target.view(-1, 1).expand(-1, inputs.size(1))
                inputs = inputs.flatten(0, 1)
                target = target.flatten(0, 1)

            output, loss, grad = self._step(inputs, target, training=training)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
            meters['loss'].update(float(loss), inputs.size(0))
            meters['prec1'].update(float(prec1), inputs.size(0))
            meters['prec5'].update(float(prec5), inputs.size(0))
            if grad is not None:
                meters['grad'].update(float(grad), inputs.size(0))

            # measure elapsed time
            meters['step'].update(time.time() - end)
            if training and tb.tboard.res_iterations:
                tb.tboard.log_results(
                    training_loss_iter=float(loss),
                    training_error1_iter=100 - float(prec1),
                    iterations=self.epoch * len(data_loader) + i)
            end = time.time()

            if i % self.print_freq == 0:
                errors = {
                    'error1_val': 100 - meters['prec1'].val,
                    'error5_val': 100 - meters['prec5'].val,
                    'error1_avg': 100 - meters['prec1'].avg,
                    'error5_avg': 100 - meters['prec5'].avg
                }
                report = str(
                    '{phase} - Epoch: [{0}][{1}/{2}]\t'
                    'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                    'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                    'Loss {meters[loss].val:.4f} ({meters[loss].avg:.4f})\t'
                    'Error@1 {errors[error1_val]:.3f} ({errors[error1_avg]:.3f})\t'
                    'Error@5 {errors[error5_val]:.3f} ({errors[error5_avg]:.3f})\t'
                    .format(self.epoch,
                            i,
                            len(data_loader),
                            phase='TRAINING' if training else 'EVALUATING',
                            meters=meters,
                            errors=errors))
                if 'grad' in meters.keys():
                    report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})' \
                        .format(meters=meters)
                logging.info(report)

            # Optional early exit when a step budget is given.
            if num_steps is not None and i >= num_steps:
                break

        return meter_results(meters)
    def train(self, epoch):
        """One training epoch over ``self.train_loader``.

        Returns ``(avg accuracy, avg loss)``. When attribute supervision is
        enabled, the attribute loss joins the identity loss after epoch 15.
        """
        self.model.train()
        correct = AverageMeter()
        losses = AverageMeter()
        prec1 = AverageMeter()
        prec2 = AverageMeter()
        prec5 = AverageMeter()

        for step, (imgs, labels, orig_attrs) in enumerate(self.train_loader):

            imgs, labels = imgs.cuda(), labels.cuda()
            pred_attrs = []
            if self.with_attribute:
                orig_attrs = orig_attrs.cuda()
                # Binarize soft attribute annotations at threshold self.xi.
                attrs = orig_attrs.detach().clone()
                attrs[attrs > self.xi] = 1.
                attrs[attrs <= self.xi] = 0.
                pred_id, pred_attrs = self.model(imgs, orig_attrs)
                assert pred_attrs.shape[-1] == 134
            else:
                pred_id = self.model(imgs, orig_attrs)
            assert pred_id.shape[-1] == self.num_classes

            if self.with_attribute:
                loss = self.criterion[0](pred_id, labels)
                loss_attrs = self.criterion[1](pred_attrs.float(), attrs.float())
                # Attribute loss only contributes after warm-up epochs.
                if epoch > 15:
                    loss += loss_attrs
            else:
                loss = self.criterion(pred_id, labels)

            #clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            prec = accuracy(pred_id.data, labels.data, topk=(1, 2, 5))

            losses.update(loss.item(), labels.size(0))
            prec1.update(prec[0].item(), labels.size(0))
            prec2.update(prec[1].item(), labels.size(0))
            prec5.update(prec[2].item(), labels.size(0))
            y_pred = pred_id.argmax(dim=1)
            acc = (y_pred == labels).sum().item() / labels.size(0) * 100
            correct.update(acc, labels.size(0)/100.)

            # tensorboard
            if self.summary_writer is not None:
                global_step = epoch * len(self.train_loader) + step
                self.summary_writer.add_scalar('train_loss', loss.item(), global_step)
                self.summary_writer.add_scalar('train_acc', 1. * correct.avg, global_step)
                self.summary_writer.add_scalar('prec1', prec1.avg, global_step)
                self.summary_writer.add_scalar('prec2', prec2.avg, global_step)
                self.summary_writer.add_scalar('prec5', prec5.avg, global_step)

            if (step + 1) % 10 == 0:
                print('[{}] '
                      'Loss {:.3f} ({:.3f})\t'
                      'Acc {:.2f} ({:.2f})\t'
                      'Prec1 {:.2%} ({:.2%})\t'
                      'Prec2 {:.2%} ({:.2%})\t'
                      'Prec5 {:.2%} ({:.2%})\t'
                      .format(step + 1,
                              losses.val, losses.avg,
                              # BUG FIX: originally printed prec2.avg in the
                              # Prec1 slot (compare the matching eval method).
                              correct.val, correct.avg,
                              prec1.val, prec1.avg,
                              prec2.val, prec2.avg,
                              prec5.val, prec5.avg
                              ))

        return correct.avg, losses.avg
Example #26
0
def forward(data_loader,
            model,
            criterion,
            epoch=0,
            training=True,
            optimizer=None):
    """Shared train/eval loop over ``data_loader``.

    Returns ``(avg loss, avg prec@1, avg prec@5)``. When ``training`` is
    True, ``optimizer`` is updated and stepped per batch.
    """
    regularizer = getattr(model, 'regularization', None)
    # Wrap in DataParallel when multiple devices were requested.
    if args.device_ids and len(args.device_ids) > 1:
        model = torch.nn.DataParallel(model, args.device_ids)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    tick = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        data_time.update(time.time() - tick)
        target = target.to(args.device)
        inputs = inputs.to(args.device, dtype=dtype)

        # forward pass (optionally with a model-level regularization term)
        output = model(inputs)
        loss = criterion(output, target)
        if regularizer is not None:
            loss += regularizer(model)

        # Some models return a list of heads; score the primary one.
        if type(output) is list:
            output = output[0]

        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        n = inputs.size(0)
        losses.update(float(loss), n)
        top1.update(float(prec1), n)
        top5.update(float(prec5), n)

        if training:
            optimizer.update(epoch, epoch * len(data_loader) + i)
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch,
                             i,
                             len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time,
                             loss=losses,
                             top1=top1,
                             top5=top5))

    return losses.avg, top1.avg, top5.avg
class Trainer(BaseTrainer):
    """Trainer for the mask/label network.

    Tracks the best validation accuracy and checkpoints whenever it
    improves. Fix over the original: validation now runs under
    ``torch.no_grad()`` so no autograd graph is retained during evaluation.
    """

    def __init__(self, cfg, network, optimizer, loss, lr_scheduler, device,
                 trainloader, testloader, writer):
        super(Trainer,
              self).__init__(cfg, network, optimizer, loss, lr_scheduler,
                             device, trainloader, testloader, writer)
        self.network = self.network.to(device)
        # Epoch-scoped meters that also log to tensorboard via `writer`.
        self.train_loss_metric = AverageMeter(writer=writer,
                                              name='Loss/train',
                                              length=len(self.trainloader))
        self.train_acc_metric = AverageMeter(writer=writer,
                                             name='Accuracy/train',
                                             length=len(self.trainloader))

        self.val_loss_metric = AverageMeter(writer=writer,
                                            name='Loss/val',
                                            length=len(self.testloader))
        self.val_acc_metric = AverageMeter(writer=writer,
                                           name='Accuracy/val',
                                           length=len(self.testloader))
        self.best_val_acc = 0

    def _checkpoint_path(self):
        """Canonical checkpoint path: '{base}_{dataset}.pth' in output_dir."""
        return os.path.join(
            self.cfg['output_dir'],
            '{}_{}.pth'.format(self.cfg['model']['base'],
                               self.cfg['dataset']['name']))

    def load_model(self):
        """Restore network and optimizer state from the canonical path."""
        state = torch.load(self._checkpoint_path())

        self.optimizer.load_state_dict(state['optimizer'])
        self.network.load_state_dict(state['state_dict'])

    def save_model(self, epoch):
        """Save epoch, network and optimizer state to the canonical path."""
        if not os.path.exists(self.cfg['output_dir']):
            os.makedirs(self.cfg['output_dir'])

        state = {
            'epoch': epoch,
            'state_dict': self.network.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        torch.save(state, self._checkpoint_path())

    def train_one_epoch(self, epoch):
        """Run one optimization pass over the training set."""
        self.network.train()
        self.train_loss_metric.reset(epoch)
        self.train_acc_metric.reset(epoch)

        for i, (img, mask, label) in enumerate(self.trainloader):
            img, mask, label = img.to(self.device), mask.to(
                self.device), label.to(self.device)
            net_mask, net_label = self.network(img)
            self.optimizer.zero_grad()
            loss = self.loss(net_mask, net_label, mask, label)
            loss.backward()
            self.optimizer.step()

            # Calculate predictions
            preds = predict(net_mask,
                            net_label,
                            score_type=self.cfg['test']['score_type'])
            targets = predict(mask,
                              label,
                              score_type=self.cfg['test']['score_type'])
            acc = calc_acc(preds, targets)
            # Update metrics
            self.train_loss_metric.update(loss.item())
            self.train_acc_metric.update(acc)

            print('Epoch: {}, iter: {}, loss: {}, acc: {}'.format(
                epoch,
                epoch * len(self.trainloader) + i, self.train_loss_metric.avg,
                self.train_acc_metric.avg))

    def train(self):
        """Full loop: train each epoch, validate, checkpoint on improvement."""
        for epoch in range(self.cfg['train']['num_epochs']):
            self.train_one_epoch(epoch)
            epoch_acc = self.validate(epoch)
            if epoch_acc > self.best_val_acc:
                self.best_val_acc = epoch_acc
                self.save_model(epoch)

    def validate(self, epoch):
        """Evaluate on the test set; returns the epoch's average accuracy.

        One randomly chosen batch's images/predictions are logged to
        tensorboard.
        """
        self.network.eval()
        self.val_loss_metric.reset(epoch)
        self.val_acc_metric.reset(epoch)

        # Pick one random batch whose images will be logged.
        seed = randint(0, len(self.testloader) - 1)

        # Inference only: avoid building autograd graphs during validation
        # (the original kept gradients alive unnecessarily).
        with torch.no_grad():
            for i, (img, mask, label) in enumerate(self.testloader):
                img, mask, label = img.to(self.device), mask.to(
                    self.device), label.to(self.device)
                net_mask, net_label = self.network(img)
                loss = self.loss(net_mask, net_label, mask, label)

                # Calculate predictions
                preds = predict(net_mask,
                                net_label,
                                score_type=self.cfg['test']['score_type'])
                targets = predict(mask,
                                  label,
                                  score_type=self.cfg['test']['score_type'])
                acc = calc_acc(preds, targets)
                # Update metrics
                self.val_loss_metric.update(loss.item())
                self.val_acc_metric.update(acc)

                if i == seed:
                    add_images_tb(self.cfg, epoch, img, preds, targets,
                                  self.writer)

        return self.val_acc_metric.avg
Example #28
0
def train(epoch, train_loader, model, criterion, optimizers, summary_writer):
    """Train ``model`` for one epoch.

    Every optimizer in ``optimizers`` is zeroed and stepped per batch, the
    center-loss parameter gradients are rescaled by the inverse loss weight,
    and per-batch loss is streamed to tensorboard and the logger.
    """
    global center_criterion
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # Make sure the snapshot directory exists before training starts.
    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR):
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    # start training
    model.train()
    tic = time.time()
    for step, batch in enumerate(train_loader):
        data_time.update(time.time() - tic)
        img, bag_id, cam_id = batch
        if cfg.CUDA:
            img = img.cuda()
            bag_id = bag_id.cuda()

        triplet_features, softmax_features = model(img)

        for opt in optimizers:
            opt.zero_grad()

        loss = criterion(softmax_features, triplet_features, bag_id)
        loss.backward()

        # Undo the center-loss weighting so the center parameters
        # themselves receive unscaled gradient updates.
        for param in center_criterion.parameters():
            param.grad.data *= (1. / cfg.TRAIN.CENTER_LOSS_WEIGHT)

        for opt in optimizers:
            opt.step()

        batch_time.update(time.time() - tic)
        losses.update(loss.item())
        # tensorboard
        if summary_writer:
            global_step = epoch * len(train_loader) + step
            summary_writer.add_scalar('loss', loss.item(), global_step)

        tic = time.time()

        if (step + 1) % cfg.TRAIN.PRINT_FREQ == 0:
            logger.info('Epoch: [{}][{}/{}]\t'
                        'Batch Time {:.3f} ({:.3f})\t'
                        'Data Time {:.3f} ({:.3f})\t'
                        'Loss {:.3f} ({:.3f}) \t'.format(
                            epoch + 1, step + 1, len(train_loader),
                            batch_time.val, batch_time.mean, data_time.val,
                            data_time.mean, losses.val, losses.mean))
    adam_param_groups = optimizers[0].param_groups
    logger.info('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
                'Adam Lr {:.2e} \t '.format(epoch + 1, batch_time.sum,
                                            losses.mean,
                                            adam_param_groups[0]['lr']))
Example #29
0
class ReidSystem():
    """Person re-identification training/evaluation driver.

    Owns the dataloaders, model, losses (cross-entropy + triplet), optimizer
    and LR scheduler, and implements the epoch loop, periodic evaluation and
    best-checkpoint saving.
    """

    def __init__(self, cfg, logger, writer):
        self.cfg, self.logger, self.writer = cfg, logger, writer
        # Define dataloader
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(
            cfg)
        # networks
        self.model = build_model(cfg, self.num_classes)
        # loss function
        self.ce_loss = nn.CrossEntropyLoss()
        self.triplet = TripletLoss(cfg.SOLVER.MARGIN)
        # optimizer and scheduler
        self.opt = make_optimizer(self.cfg, self.model)
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

        self._construct()

    def _construct(self):
        """Initialize bookkeeping counters and training-schedule settings."""
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        self.use_dp = False
        self.use_ddp = False  # set here but never enabled anywhere below

    def loss_fns(self, outputs, labels):
        """Return a dict with the two training losses.

        ``outputs[0]`` feeds cross-entropy; ``outputs[1]`` feeds the triplet
        loss, of which only the first returned element is kept.
        """
        ce_loss = self.ce_loss(outputs[0], labels)
        triplet_loss = self.triplet(outputs[1], labels)[0]

        return {'ce_loss': ce_loss, 'triplet': triplet_loss}

    def on_train_begin(self):
        """One-time setup: meters, checkpoint dir, DataParallel wrap, cuda."""
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR,
                                    self.cfg.DATASETS.TEST_NAMES,
                                    self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        # NOTE(review): assumes CUDA_VISIBLE_DEVICES is set in the
        # environment; raises KeyError otherwise — confirm launch scripts.
        self.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        self.use_dp = (len(self.gpus) > 0) and (self.cfg.MODEL.DIST_BACKEND
                                                == 'dp')

        if self.use_dp:
            self.model = nn.DataParallel(self.model)

        self.model = self.model.cuda()

        self.model.train()

    def on_epoch_begin(self):
        """Reset per-epoch counters/meters and build a fresh prefetcher."""
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        self.running_loss.reset()

        self.tng_prefetcher = data_prefetcher(self.tng_dataloader)

    def training_step(self, batch):
        """Forward/backward/step for one batch, with console + TB logging."""
        inputs, labels, _ = batch
        outputs = self.model(inputs, labels)
        loss_dict = self.loss_fns(outputs, labels)

        # Sum individual losses while building the progress line.
        total_loss = 0
        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_dataloader)} '
        for loss_name, loss_value in loss_dict.items():
            total_loss += loss_value
            print_str += (loss_name + f': {loss_value.item():.3f} ')
        loss_dict['total_loss'] = total_loss.item()
        print_str += f'Total loss: {total_loss.item():.3f} '
        print(print_str, end=' ')

        # Periodic tensorboard logging.
        if (self.global_step + 1) % self.log_interval == 0:
            self.writer.add_scalar('cross_entropy_loss', loss_dict['ce_loss'],
                                   self.global_step)
            self.writer.add_scalar('triplet_loss', loss_dict['triplet'],
                                   self.global_step)
            self.writer.add_scalar('total_loss', loss_dict['total_loss'],
                                   self.global_step)

        self.running_loss.update(total_loss.item())

        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        """Log epoch summary (loss, lr, wall time) and step the LR schedule."""
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        print('')
        self.logger.info(
            f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} '
            f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s'
        )
        # update learning rate
        self.lr_sched.step()

    def test(self):
        """Extract features for the query+gallery set and compute CMC/mAP.

        The first ``num_query`` rows of the feature matrix are queries, the
        rest gallery. Similarity is the (optionally L2-normalized) inner
        product; it is negated so ``evaluate`` receives a distance-like
        matrix. Returns ``{'rank1': ..., 'mAP': ...}``.
        """
        # convert to eval mode
        self.model.eval()

        feats, pids, camids = [], [], []
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        # Prefetcher signals exhaustion with a None first element.
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = self.model(img)
            feats.append(feat)
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))

            batch = val_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(pids[:self.num_query])
        q_camids = np.asarray(camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        # m, n = qf.shape[0], gf.shape[0]
        distmat = torch.mm(qf, gf.t()).cpu().numpy()
        # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        #           torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # distmat.addmm_(1, -2, qf, gf.t())
        # distmat = distmat.numpy()
        cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        self.writer.add_scalar('rank1', cmc[0], self.global_step)
        self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def train(self):
        """Full training run: epoch loop, periodic test, checkpointing."""
        self.on_train_begin()
        for epoch in range(self.max_epochs):
            self.on_epoch_begin()
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            # Evaluate every eval_period epochs; track the best mAP.
            if (epoch + 1) % self.eval_period == 0:
                metric_dict = self.test()
                if metric_dict['mAP'] > self.best_mAP:
                    is_best = True
                    self.best_mAP = metric_dict['mAP']
                else:
                    is_best = False
                self.save_checkpoints(is_best)

            torch.cuda.empty_cache()

    def save_checkpoints(self, is_best):
        """Save the (unwrapped) model weights; copy to model_best.pth if best."""
        if self.use_dp:
            # Unwrap DataParallel so the checkpoint loads without the wrapper.
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        # TODO: add optimizer state dict and lr scheduler
        filepath = os.path.join(self.model_save_dir,
                                f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)
        if is_best:
            best_filepath = os.path.join(self.model_save_dir, 'model_best.pth')
            shutil.copyfile(filepath, best_filepath)
Example #30
0
def train(train_loader, model, criterion, optimizer, epoch, device):
    """Train ``model`` for one epoch on ``train_loader``.

    Batches with an extra leading duplicate dimension (dim > 4) are flattened
    into the batch axis before the forward pass. Tracks batch/data time,
    loss and top-1..top-4 precision, printing every ``args.print_freq`` steps.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    topk = [AverageMeter() for i in range(4)]

    # switch to train mode
    model.train()

    end = time.time()
    # FIX: loop variable renamed from `input` to `inputs` so the `input`
    # builtin is no longer shadowed; behavior is unchanged.
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.float()
        if inputs.dim() > 4:
            # Collapse the duplicate dimension into the batch dimension.
            inputs = inputs.reshape(inputs.shape[0] * inputs.shape[1],
                                    inputs.shape[2], inputs.shape[3],
                                    inputs.shape[4])
            target = target.reshape(target.shape[0] * target.shape[1],
                                    target.shape[2])

        inputs = inputs.to(device)
        target = target.to(device)

        # compute output; loss/accuracy are computed on CPU
        output = model(inputs)
        output = output.cpu()
        target = target.cpu()
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec = accuracy(output, target, topk=4)
        for k in range(4):
            topk[k].update(prec[k], inputs.size(0))
        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\n'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@2 {top2.val:.3f} ({top2.avg:.3f})\t'
                  'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
                  'Prec@4 {top4.val:.3f} ({top4.avg:.3f})\t'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=topk[0],
                      top2=topk[1],
                      top3=topk[2],
                      top4=topk[3]))
Example #31
0
    def forward(self,
                data_loader,
                num_steps=None,
                training=False,
                average_output=False,
                chunk_batch=1):
        """Run one pass of the model over ``data_loader``.

        Tracks step/data time, loss and top-1/top-5 precision in
        AverageMeters and returns their averages plus derived error rates.
        ``num_steps`` caps the number of batches processed; ``chunk_batch``
        splits each batch inside ``self._step``.
        """

        meters = {
            name: AverageMeter()
            for name in ['step', 'data', 'loss', 'prec1', 'prec5']
        }
        if training and self.grad_clip > 0:
            meters['grad'] = AverageMeter()

        # NOTE(review): by operator precedence this condition reads
        # (training and isinstance(...)) or (chunk_batch > 1) — confirm
        # that grouping is intended.
        batch_first = True
        if training and isinstance(self.model,
                                   nn.DataParallel) or chunk_batch > 1:
            batch_first = False

        def meter_results(meters):
            # Collapse meters to their averages and derive error rates.
            results = {name: meter.avg for name, meter in meters.items()}
            results['error1'] = 100. - results['prec1']
            results['error5'] = 100. - results['prec5']
            return results

        end = time.time()

        for i, (inputs, target) in enumerate(data_loader):
            duplicates = inputs.dim() > 4  # B x D x C x H x W
            # Periodically re-estimate the loss scale when each sample has
            # multiple duplicate crops: ratio of mean per-crop gradient norm
            # to the gradient norm of the flattened batch.
            if training and duplicates and self.adapt_grad_norm is not None \
                    and i % self.adapt_grad_norm == 0:
                grad_mean = 0
                num = inputs.size(1)
                for j in range(num):
                    grad_mean += float(
                        self._grad_norm(inputs.select(1, j), target))
                grad_mean /= num
                grad_all = float(
                    self._grad_norm(
                        *_flatten_duplicates(inputs, target, batch_first)))
                self.grad_scale = grad_mean / grad_all
                logging.info('New loss scale: %s', self.grad_scale)

            # measure data loading time
            meters['data'].update(time.time() - end)
            if duplicates:  # multiple versions for each sample (dim 1)
                inputs, target = _flatten_duplicates(
                    inputs,
                    target,
                    batch_first,
                    expand_target=not average_output)

            output, loss, grad = self._step(inputs,
                                            target,
                                            training=training,
                                            average_output=average_output,
                                            chunk_batch=chunk_batch)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            meters['loss'].update(float(loss), inputs.size(0))
            meters['prec1'].update(float(prec1), inputs.size(0))
            meters['prec5'].update(float(prec5), inputs.size(0))
            # NOTE(review): assumes _step returns a non-None grad only when
            # the 'grad' meter exists (grad_clip > 0); otherwise this line
            # would raise KeyError — confirm.
            if grad is not None:
                meters['grad'].update(float(grad), inputs.size(0))

            # measure elapsed time
            meters['step'].update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0 or i == len(data_loader) - 1:
                report = str(
                    '{phase} - Epoch: [{0}][{1}/{2}]\t'
                    'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                    'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                    'Loss {meters[loss].val:.4f} ({meters[loss].avg:.4f})\t'
                    'Prec@1 {meters[prec1].val:.3f} ({meters[prec1].avg:.3f})\t'
                    'Prec@5 {meters[prec5].val:.3f} ({meters[prec5].avg:.3f})\t'
                    .format(self.epoch,
                            i,
                            len(data_loader),
                            phase='TRAINING' if training else 'EVALUATING',
                            meters=meters))
                if 'grad' in meters.keys():
                    report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})'\
                        .format(meters=meters)
                logging.info(report)
                # Hook for external observers (e.g. visualization/debugging).
                self.observe(trainer=self,
                             model=self._model,
                             optimizer=self.optimizer,
                             data=(inputs, target))
                self.stream_meters(meters,
                                   prefix='train' if training else 'eval')
                if training:
                    self.write_stream(
                        'lr',
                        (self.training_steps, self.optimizer.get_lr()[0]))

            # Early exit once the requested number of steps is reached.
            if num_steps is not None and i >= num_steps:
                break

        return meter_results(meters)
Example #32
0
    def train(self, epoch, data_loader, optimizer1, optimizer2):
        """Run one training epoch, stepping both optimizers on each batch.

        Tracks loss and the two precision metrics (OIM and score) returned
        by ``self._forward``, printing a summary every 50 iterations.
        """
        self.model.train()

        batch_timer = AverageMeter()
        data_timer = AverageMeter()
        loss_meter = AverageMeter()
        oim_prec_meter = AverageMeter()
        score_prec_meter = AverageMeter()

        tic = time.time()
        for batch_idx, raw_batch in enumerate(data_loader):
            data_timer.update(time.time() - tic)

            model_inputs, targets = self._parse_data(raw_batch)

            loss, prec_oim, prec_score = self._forward(model_inputs, targets)

            n = targets.size(0)
            loss_meter.update(loss.item(), n)
            oim_prec_meter.update(prec_oim, n)
            score_prec_meter.update(prec_score, n)

            # Both optimizers share the single backward pass.
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss.backward()
            optimizer1.step()
            optimizer2.step()

            batch_timer.update(time.time() - tic)
            tic = time.time()
            print_freq = 50
            if (batch_idx + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'prec_oim {:.2%} ({:.2%})\t'
                      'prec_score {:.2%} ({:.2%})\t'
                      .format(epoch, batch_idx + 1, len(data_loader),
                              loss_meter.val, loss_meter.avg,
                              oim_prec_meter.val, oim_prec_meter.avg,
                              score_prec_meter.val, score_prec_meter.avg))
Example #33
0
    def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
        """Score every query tracklet against every gallery tracklet.

        Gallery features are extracted online and popped one tracklet
        (``gallerytranum[k]`` frames) at a time to bound memory. For each
        query/gallery pair the distance is the mean of the scores at or
        below the 20th percentile. Returns the result of ``evaluate_seq``
        on the final query-by-gallery distance matrix.
        """
        self.cnn_model.eval()
        self.att_model.eval()
        self.classifier_model.eval()

        querypid = queryinfo.pid
        querycamid = queryinfo.camid
        querytranum = queryinfo.tranum

        gallerypid = galleryinfo.pid
        gallerycamid = galleryinfo.camid
        gallerytranum = galleryinfo.tranum

        pooled_probe, hidden_probe = self.extract_feature(query_loader)

        querylen = len(querypid)
        gallerylen = len(gallerypid)

        # online gallery extraction
        single_distmat = np.zeros((querylen, gallerylen))
        gallery_resize = 0
        gallery_popindex = 0
        gallery_popsize = gallerytranum[gallery_popindex]

        gallery_resfeatures = 0
        gallery_resraw = 0

        gallery_empty = True
        preimgs = 0
        preflows = 0

        # time
        gallery_time = AverageMeter()
        end = time.time()

        # Hoisted out of the loop: the device never changes between batches.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        for i, (imgs, flows, _, _) in enumerate(gallery_loader):
            imgs = to_torch(imgs)
            flows = to_torch(flows)
            imgs = imgs.to(device)
            flows = flows.to(device)

            with torch.no_grad():
                seqnum = imgs.size(0)
                if i == 0:
                    # Keep the first full batch to pad a short final batch.
                    preimgs = imgs
                    preflows = flows

                if gallery_empty:
                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)

                    gallery_resfeatures = out_feat
                    gallery_resraw = out_raw

                    gallery_empty = False

                elif imgs.size(0) < gallery_loader.batch_size:
                    # Short last batch: pad with frames from the first batch
                    # so the model sees a full batch, then keep only the
                    # genuine outputs.
                    flaw_batchsize = imgs.size(0)
                    cat_batchsize = gallery_loader.batch_size - flaw_batchsize
                    imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                    flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)

                    out_feat = out_feat[0:flaw_batchsize]
                    out_raw = out_raw[0:flaw_batchsize]

                    gallery_resfeatures = torch.cat((gallery_resfeatures, out_feat), 0)
                    gallery_resraw = torch.cat((gallery_resraw, out_raw), 0)

                else:
                    out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)

                    gallery_resfeatures = torch.cat((gallery_resfeatures, out_feat), 0)
                    gallery_resraw = torch.cat((gallery_resraw, out_raw), 0)

            gallery_resize = gallery_resize + seqnum

            # Pop complete gallery tracklets as soon as enough frames are buffered.
            while gallery_popsize <= gallery_resize:

                if (gallery_popindex + 1) % 50 == 0:
                    print('gallery--{:04d}'.format(gallery_popindex))
                gallery_popfeatures = gallery_resfeatures[0:gallery_popsize, :]
                gallery_popraw = gallery_resraw[0:gallery_popsize, :]

                if gallery_popsize < gallery_resize:
                    gallery_resfeatures = gallery_resfeatures[gallery_popsize:gallery_resize, :]
                    gallery_resraw = gallery_resraw[gallery_popsize:gallery_resize, :]
                else:
                    gallery_resfeatures = 0
                    gallery_resraw = 0
                    gallery_empty = True

                gallery_resize = gallery_resize - gallery_popsize

                pooled_gallery, pooled_raw = self.att_model.selfpooling_model(gallery_popfeatures, gallery_popraw)
                probesize = pooled_probe.size()
                gallerysize = pooled_gallery.size()
                probe_batch = probesize[0]
                gallery_batch = gallerysize[0]
                gallery_num = gallerysize[1]
                # BUGFIX: `unsqueeze` is not in-place — the original call
                # discarded its result (it only worked because `expand`
                # prepends new dims). Assign explicitly before broadcasting.
                pooled_gallery = pooled_gallery.unsqueeze(0)
                pooled_gallery = pooled_gallery.expand(probe_batch, gallery_batch, gallery_num)

                encode_scores = self.classifier_model(pooled_probe, pooled_gallery)

                encode_size = encode_scores.size()
                encodemat = encode_scores.view(-1, 2)
                # BUGFIX: pass `dim` explicitly — implicit-dim softmax is
                # deprecated/ambiguous; normalize over the 2 class scores.
                encodemat = F.softmax(encodemat, dim=1)
                encodemat = encodemat.view(encode_size[0], encode_size[1], 2)
                distmat_qall_g = encodemat[:, :, 0]

                q_start = 0
                for qind, qnum in enumerate(querytranum):
                    distmat_qg = distmat_qall_g[q_start:q_start + qnum, :]
                    distmat_qg = distmat_qg.data.cpu().numpy()
                    percile = np.percentile(distmat_qg, 20)

                    # BUGFIX: boolean-mask indexing never yields None, so the
                    # old `is not None` test was always true; check emptiness
                    # of the selection instead.
                    selected = distmat_qg[distmat_qg <= percile]
                    if selected.size > 0:
                        distmean = np.mean(selected)
                    else:
                        distmean = np.mean(distmat_qg)

                    single_distmat[qind, gallery_popindex] = distmean
                    q_start = q_start + qnum

                gallery_popindex = gallery_popindex + 1

                if gallery_popindex < gallerylen:
                    gallery_popsize = gallerytranum[gallery_popindex]
                gallery_time.update(time.time() - end)
                end = time.time()

        return evaluate_seq(single_distmat, querypid, querycamid, gallerypid, gallerycamid)