def train(self, epoch):
        print('Starting training')
        self.model.train()
        accum = defaultdict(float)
        #loss_loc = WHD_losses.WeightedHausdorffDistance(resized_height=self.opts['p_num'],resized_width=self.opts['p_num'],
        #                                                return_2_terms=True,
        #                                                device=device)
        #loss_loc = WHD_losses.AveragedHausdorffLoss()
        loss_loc = losses.WeightedHausdorffDistance(resized_height=224,
                                                    resized_width=224,
                                                    return_2_terms=True,
                                                    device=device)
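        # return_2_terms=True makes the Weighted Hausdorff loss return its two
        # distance terms separately; they are summed into the loss further below.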
        # focalloss = focal.FocalLoss(None,None,None,'mean')
        focalloss = focal.FocalLoss()
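        # Note: focalloss is instantiated here but not used by any of the loss
        # terms below.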
        # To accumulate stats for printing
        for step, data in enumerate(self.train_loader):
            if len(data['img']) == 1:
                continue
            if self.opts['get_point_annotation']:
                img = data['img'].to(device)
                annotation = data['annotation_prior'].to(device).unsqueeze(1)
                img = torch.cat([img, annotation], 1)
            else:
                img = data['img'].to(device)
            self.optimizer.zero_grad()
            if self.global_step % self.opts['val_freq'] == 0 and not self.opts[
                    'debug']:
                self.validate()
                self.save_checkpoint(epoch)
            output = self.model(img, data['fwd_poly'])
            loss_sum = 0
            pred_cps = output['pred_polys'][-1]
            pred_polys = self.spline.sample_point(pred_cps)
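            # Densely sample a (bs, p_num, 2) polygon, normalized to [0, 1],
            # from the predicted spline control points.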
            gt_right_order, poly_mathcing_loss_sum = losses.poly_mathcing_loss(
                self.opts['p_num'],
                pred_polys,
                data['gt_poly'].to(device),
                loss_type=self.opts['loss_type'])
            # Added by dzh: contour refinement.
            # Initialize the contour box (level-set alignment).
            level_set_config_dict = {
                'step_ckpts': [50],
                'lambda_': 0.0,
                'alpha': 1,
                'smoothing': 1,
                'render_radius': -1,
                'is_gt_semantic': True,
                'method': 'MLS',
                'balloon': 1,
                'threshold': 0.99,
                'merge_weight': 0.5
            }
            cbox = ContourBox.LevelSetAlignment(n_workers=1,
                                                fn_post_process_callback=None,
                                                config=level_set_config_dict)
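            # Run the level-set alignment: evolve contours from the predicted
            # edge logits toward the GT edge mask ('MLS' = morphological level
            # sets). output_contour appears to be indexed as
            # (wrapper, sample, checkpoint, H, W); the slice below keeps the
            # step-50 checkpoint for every sample.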
            output_contour, _ = cbox(
                {
                    'seg': np.expand_dims(data['edge_mask'], 0),
                    'bdry': None
                },
                np.expand_dims(
                    output['edge_logits'].view(
                        data['edge_mask'].shape).cpu().detach().numpy(), 0))
            masks_step = output_contour[0, :, 0, :, :]
            # Added by dzh, 7.18: supervise the edge logits with the
            # level-set-refined masks rather than the raw GT edge mask.
            edge_annotation_loss = 0
            curr_fp_edge_loss = losses.fp_edge_loss(
                torch.from_numpy(masks_step).to(device),
                output['edge_logits'])
            edge_annotation_loss += curr_fp_edge_loss
            # Rasterize each predicted polygon into a 224x224 mask for the
            # Weighted Hausdorff loss (note: rasterization happens in numpy,
            # so these two terms do not backpropagate to the network).
            pred_mask_list = []
            for i in range(pred_polys.shape[0]):
                pred_poly_mask = np.zeros((224, 224), dtype=np.float32)
                # Polygons are normalized to [0, 1]; scale by 224 to match the
                # 224x224 mask (the earlier factor of 36 was a leftover from a
                # 36x36 grid).
                poly_px = np.floor(pred_polys[i].detach().cpu().numpy() *
                                   224).astype(np.int32)
                pred_poly_mask = utils.draw_poly(pred_poly_mask, poly_px)
                pred_mask_list.append(pred_poly_mask)
            pred_masks = torch.from_numpy(
                np.array(pred_mask_list, dtype=np.float32)).cuda()
            # Original sizes expected by the Hausdorff loss, one per sample.
            orig_sizes = [[224, 224] for _ in range(pred_masks.shape[0])]
            # Ground-truth polygon masks, one tensor per sample.
            gt_masks = []
            poly_mask_ori = data['poly_mask']
            for hh in range(pred_masks.shape[0]):
                gt_masks.append(
                    torch.FloatTensor(
                        poly_mask_ori[hh].astype(np.float32)).cuda())
            term1, term2 = loss_loc.forward(
                pred_masks, gt_masks,
                torch.FloatTensor(np.array(orig_sizes,
                                           dtype=np.float32)).cuda())
            fp_vertex_loss = 0.1 * (term1 + term2) + poly_mathcing_loss_sum
            loss_sum += fp_vertex_loss
            loss_sum += edge_annotation_loss
            # (Disabled experiment: an IoU "loss" that rasterized the
            # predicted and GT polygons with metrics.iou_from_poly and added
            # -0.1 * mean IoU to loss_sum; IoU computed this way is not
            # differentiable.)
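            # Rasterize the predicted and GT polygons and penalize their
            # per-pixel L1 difference. Note that both masks come from numpy
            # rasterization, so this term carries no gradient back to the
            # network; it mainly shifts the reported loss value.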
            with torch.no_grad():
                gt_mask_0 = []
                pred_mask_0 = []
                orig_poly = data['orig_poly']
                preds = pred_polys.detach().data.cpu().numpy()
                for i in range(preds.shape[0]):
                    curr_pred_poly = np.floor(preds[i] * 224).astype(np.int32)
                    curr_gt_poly = np.floor(orig_poly[i] * 224).astype(
                        np.int32)
                    cur_iou, masks = metrics.iou_from_poly(
                        np.array(curr_pred_poly, dtype=np.int32),
                        np.array(curr_gt_poly, dtype=np.int32), 224, 224)
                    gt_mask_0.append(masks[1])
                    pred_mask_0.append(masks[0])
            gt_mask_1 = torch.from_numpy(
                np.array(gt_mask_0)).to(device).float()
            pred_mask_1 = torch.from_numpy(
                np.array(pred_mask_0)).to(device).float()
            # The masks appear to be 0/255; dividing by 250 roughly
            # normalizes them to [0, 1] before the L1 comparison.
            mask_loss = torch.sum(
                torch.abs(gt_mask_1 / 250 - pred_mask_1 / 250))
            loss_sum += mask_loss
            loss_sum.backward()
            if 'grad_clip' in self.opts.keys():
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.opts['grad_clip'])
            self.optimizer.step()
            preds = pred_polys.detach().data.cpu().numpy()
            with torch.no_grad():
                # Get IoU
                iou = 0
                orig_poly = data['orig_poly']

                for i in range(preds.shape[0]):
                    curr_pred_poly = np.floor(preds[i] * 224).astype(np.int32)
                    curr_gt_poly = np.floor(orig_poly[i] * 224).astype(
                        np.int32)

                    cur_iou, masks = metrics.iou_from_poly(
                        np.array(curr_pred_poly, dtype=np.int32),
                        np.array(curr_gt_poly, dtype=np.int32), 224, 224)
                    iou += cur_iou
                iou = iou / preds.shape[0]
                accum['loss'] += float(loss_sum.item())
                accum['iou'] += iou
                accum['length'] += 1
                if self.opts['edge_loss']:
                    accum['edge_annotation_loss'] += float(
                        edge_annotation_loss.item())
                print(
                    "[%s] Epoch: %d, Step: %d, Polygon Loss: %f,  IOU: %f" \
                    % (str(datetime.now()), epoch, self.global_step, accum['loss'] / accum['length'], accum['iou'] / accum['length']))
                if step % self.opts['print_freq'] == 0:
                    # Mean of accumulated values
                    for k in accum.keys():
                        if k == 'length':
                            continue
                        accum[k] /= accum['length']

                    # Add summaries
                    masks = np.expand_dims(masks, -1).astype(
                        np.uint8)  # Add a channel dimension
                    masks = np.tile(masks, [1, 1, 1, 3])  # Make [2, H, W, 3]
                    img = (data['img'].cpu().numpy()[-1, ...] * 255).astype(
                        np.uint8)
                    img = np.transpose(img, [1, 2, 0])  # Make [H, W, 3]
                    self.writer.add_image('pred_mask', masks[0],
                                          self.global_step)
                    self.writer.add_image('gt_mask', masks[1],
                                          self.global_step)
                    self.writer.add_image('image', img, self.global_step)
                    self.writer.add_image(
                        'edge_acm_gt',
                        np.tile(
                            np.expand_dims(masks_step[preds.shape[0] - 1],
                                           axis=-1).astype(np.uint8),
                            [1, 1, 3]), self.global_step)
                    pred_edge_mask = np.tile(
                        np.expand_dims(
                            output['edge_logits'].detach().cpu().numpy()
                            [preds.shape[0] - 1] * 255,
                            axis=-1).astype(np.uint8),
                        [1, 1, 3]).reshape(28, 28, 3)
                    self.writer.add_image('pred_edge', pred_edge_mask,
                                          self.global_step)
                    for k in accum.keys():
                        if k == 'length':
                            continue
                        self.writer.add_scalar(k, accum[k], self.global_step)
                    print(
                    "[%s] Epoch: %d, Step: %d, Polygon Loss: %f,  IOU: %f" \
                    % (str(datetime.now()), epoch, self.global_step, accum['loss'], accum['iou']))

                    accum = defaultdict(float)

            del output, masks, pred_polys, preds, loss_sum
            self.global_step += 1
    def train(self, epoch):
        print('Starting training')

        self.model.train()

        accum = defaultdict(float)
        # To accumulate stats for printing

        for step, data in enumerate(self.train_loader):

            if len(data['img']) == 1:
                continue

            if self.opts['get_point_annotation']:
                img = data['img'].to(device)
                annotation = data['annotation_prior'].to(device).unsqueeze(1)

                img = torch.cat([img, annotation], 1)
            else:
                img = data['img'].to(device)

            self.optimizer.zero_grad()
            if self.global_step % self.opts['val_freq'] == 0 and not self.opts[
                    'debug']:

                #self.validate()
                self.save_checkpoint(epoch)

            output = self.model(img, data['fwd_poly'],
                                data['sampled_interactive'])
            output_prob = F.softmax(output['x_prob'], dim=1)
            loss_sum = 0
            pred_cps = output['pred_polys'][-1]

            pred_polys = self.spline.sample_point(pred_cps)  #(bs, p_num, 2)
            #pred_polys = pred_cps
            # loss 1
            gt_right_order, poly_mathcing_loss_sum = losses.poly_mathcing_loss(
                self.opts['p_num'],
                pred_polys,
                data['gt_poly'].to(device),
                loss_type=self.opts['loss_type'])
            loss_sum += poly_mathcing_loss_sum

            edge_annotation_loss = 0
            # loss 2
            curr_fp_edge_loss = self.opts['fp_weight'] * losses.fp_edge_loss(
                data['edge_mask'].to(device), output['edge_logits'])
            edge_annotation_loss += curr_fp_edge_loss
            # loss 3
            fp_vertex_loss = self.opts['fp_weight'] * losses.fp_vertex_loss(
                data['vertex_mask'].to(device), output['vertex_logits'])
            # loss 4 , Dice loss, Boundary loss

            # dice_loss = losses.GeneralizedDice(output_prob, data['onehot_label'].cuda())
            boundary_loss = losses.SurfaceLoss(output_prob,
                                               data['mask_distmap'].cuda())
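            # SurfaceLoss is presumably a boundary loss in the style of
            # Kervadec et al.: softmax probabilities weighted by a signed
            # distance map of the GT boundary, penalizing probability mass
            # far from the true contour.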

            edge_annotation_loss += fp_vertex_loss

            loss_sum += edge_annotation_loss
            # loss_sum += dice_loss
            loss_sum = loss_sum + boundary_loss
            loss_sum.backward()

            if 'grad_clip' in self.opts.keys():
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.opts['grad_clip'])

            self.optimizer.step()

            preds = pred_polys.detach().data.cpu().numpy()
            with torch.no_grad():
                # Get IoU
                iou = 0
                orig_poly = data['orig_poly']

                for i in range(preds.shape[0]):
                    curr_pred_poly = np.floor(preds[i] * 224).astype(np.int32)
                    curr_gt_poly = np.floor(orig_poly[i] * 224).astype(
                        np.int32)

                    cur_iou, masks = metrics.iou_from_poly(
                        np.array(curr_pred_poly, dtype=np.int32),
                        np.array(curr_gt_poly, dtype=np.int32), 224, 224)
                    iou += cur_iou
                iou = iou / preds.shape[0]
                accum['loss'] += float(loss_sum.item())
                accum['iou'] += iou
                #accum['dice'] += float(dice_loss)
                accum['boundary'] += float(boundary_loss)
                accum['length'] += 1
                if self.opts['edge_loss']:
                    accum['edge_annotation_loss'] += float(
                        edge_annotation_loss.item())
                print(
                    "[%s] Epoch: %d, Step: %d, Polygon Loss: %f, IOU: %f, Boundary: %f" \
                    % (str(datetime.now()), epoch, self.global_step, accum['loss'] / accum['length'], accum['iou'] / accum['length'], accum['boundary'] / accum['length']))
                if step % self.opts['print_freq'] == 0:
                    # Mean of accumulated values
                    for k in accum.keys():
                        if k == 'length':
                            continue
                        accum[k] /= accum['length']

                    # Add summaries
                    masks = np.expand_dims(masks, -1).astype(
                        np.uint8)  # Add a channel dimension
                    masks = np.tile(masks, [1, 1, 1, 3])  # Make [2, H, W, 3]
                    img = (data['img'].cpu().numpy()[-1, ...] * 255).astype(
                        np.uint8)
                    img = np.transpose(img, [1, 2, 0])  # Make [H, W, 3]

                    self.writer.add_image('pred_mask', masks[0],
                                          self.global_step)
                    self.writer.add_image('gt_mask', masks[1],
                                          self.global_step)
                    self.writer.add_image('image', img, self.global_step)

                    for k in accum.keys():
                        if k == 'length':
                            continue
                        self.writer.add_scalar(k, accum[k], self.global_step)

                    print(
                    "[%s] Epoch: %d, Step: %d, Polygon Loss: %f,  IOU: %f" \
                    % (str(datetime.now()), epoch, self.global_step, accum['loss'], accum['iou']))

                    accum = defaultdict(float)

            del output, masks, pred_polys, preds, loss_sum
            self.global_step += 1
    def train(self, epoch):
        print('Starting training')
        self.model.train()

        accum = defaultdict(float)
        # To accumulate stats for printing
        for step, data in enumerate(self.train_loader):
            
            if self.global_step % self.opts['val_freq'] == 0:
                self.validate()
                self.save_checkpoint(epoch)             

            # Forward pass
            output = self.model(data['img'].to(device), data['fwd_poly'].to(device))
                
            # Smoothed targets
            dt_targets = utils.dt_targets_from_class(output['poly_class'].cpu().numpy(),
                                                     self.grid_size, self.opts['dt_threshold'])
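            # dt_targets_from_class presumably spreads each one-hot vertex
            # target over nearby grid cells with a thresholded distance
            # transform (as in Polygon-RNN++), so near-misses are penalized
            # less than distant ones.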

            # Get losses
            loss = losses.poly_vertex_loss_mle(torch.from_numpy(dt_targets).to(device),
                                               data['mask'].to(device), output['logits'])
            fp_edge_loss = self.opts['fp_weight'] * losses.fp_edge_loss(data['edge_mask'].to(device),
                                                                        output['edge_logits'])
            fp_vertex_loss = self.opts['fp_weight'] * losses.fp_vertex_loss(data['vertex_mask'].to(device),
                                                                            output['vertex_logits'])

            total_loss = loss + fp_edge_loss + fp_vertex_loss

            # Backward pass
            self.optimizer.zero_grad()
            total_loss.backward()
            
            if 'grad_clip' in self.opts.keys():
                nn.utils.clip_grad_norm_(self.model.parameters(), self.opts['grad_clip']) 

            self.optimizer.step()

            # Get accuracy
            accuracy = metrics.train_accuracy(output['poly_class'].cpu().numpy(), data['mask'].cpu().numpy(),
                                              output['pred_polys'].cpu().numpy(), self.grid_size)

            # Get IoU
            iou = 0
            pred_polys = output['pred_polys'].cpu().numpy()
            gt_polys = data['full_poly']

            for i in range(pred_polys.shape[0]):
                p = pred_polys[i]
                p = utils.get_masked_poly(p, self.grid_size)
                p = utils.class_to_xy(p, self.grid_size)
                # Avoid shadowing the loop index with the returned IoU value.
                cur_iou, masks = metrics.iou_from_poly(p, gt_polys[i], self.grid_size, self.grid_size)
                iou += cur_iou

            iou = iou / pred_polys.shape[0]

            accum['loss'] += float(loss)
            accum['fp_edge_loss'] += float(fp_edge_loss)
            accum['fp_vertex_loss'] += float(fp_vertex_loss)
            accum['accuracy'] += accuracy
            accum['iou'] += iou
            accum['length'] += 1
                
            if step % self.opts['print_freq'] == 0:
                # Mean of accumulated values
                for k in accum.keys():
                    if k == 'length':
                        continue
                    accum[k] /= accum['length']

                # Add summaries
                masks = np.expand_dims(masks, -1).astype(np.uint8) # Add a channel dimension
                masks = np.tile(masks, [1, 1, 1, 3]) # Make [2, H, W, 3]
                img = (data['img'].cpu().numpy()[-1,...]*255).astype(np.uint8)
                img = np.transpose(img, [1,2,0]) # Make [H, W, 3]
                vert_logits = np.reshape(output['vertex_logits'][-1, ...].detach().cpu().numpy(), (self.grid_size, self.grid_size, 1))
                edge_logits = np.reshape(output['edge_logits'][-1, ...].detach().cpu().numpy(), (self.grid_size, self.grid_size, 1))
                vert_logits = (1/(1 + np.exp(-vert_logits))*255).astype(np.uint8)
                edge_logits = (1/(1 + np.exp(-edge_logits))*255).astype(np.uint8)
                vert_logits = np.tile(vert_logits, [1, 1, 3]) # Make [H, W, 3]
                edge_logits = np.tile(edge_logits, [1, 1, 3]) # Make [H, W, 3]
                vertex_mask = np.tile(np.expand_dims(data['vertex_mask'][-1,...].cpu().numpy().astype(np.uint8)*255,-1),(1,1,3))
                edge_mask = np.tile(np.expand_dims(data['edge_mask'][-1,...].cpu().numpy().astype(np.uint8)*255,-1),(1,1,3))

                self.writer.add_image('pred_mask', masks[0], self.global_step)
                self.writer.add_image('gt_mask', masks[1], self.global_step)
                self.writer.add_image('image', img, self.global_step)
                self.writer.add_image('vertex_logits', vert_logits, self.global_step)
                self.writer.add_image('edge_logits', edge_logits, self.global_step)
                self.writer.add_image('edge_mask', edge_mask, self.global_step)
                self.writer.add_image('vertex_mask', vertex_mask, self.global_step)
                    
                if self.opts['return_attention']:
                    att = output['attention'][-1, 1:4, ...].detach().cpu().numpy()
                    att = np.transpose(att, [0, 2, 3, 1]) # Make [T, H, W, 1]
                    att = np.tile(att, [1, 1, 1, 3]) # Make [T, H, W, 3]
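                    # Normalize each attention map to [0, 255] and upsample it
                    # (8x, Gaussian-smoothed) purely for visualization.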
                    def _scale(att):
                        att = att/np.max(att)
                        return (att*255).astype(np.int32)
                    self.writer.add_image('attention_1', pyramid_expand(_scale(att[0]), upscale=8, sigma=10), self.global_step)
                    self.writer.add_image('attention_2', pyramid_expand(_scale(att[1]), upscale=8, sigma=10), self.global_step)
                    self.writer.add_image('attention_3', pyramid_expand(_scale(att[2]), upscale=8, sigma=10), self.global_step)
                    
                for k in accum.keys():
                    if k == 'length':
                        continue
                    self.writer.add_scalar(k, accum[k], self.global_step)
                print("[%s] Epoch: %d, Step: %d, Polygon Loss: %f, Edge Loss: %f, Vertex Loss: %f, Accuracy: %f, IOU: %f"\
                    %(str(datetime.now()), epoch, self.global_step, accum['loss'], accum['fp_edge_loss'], accum['fp_vertex_loss'],\
                      accum['accuracy'], accum['iou']))
                
                accum = defaultdict(float)

            del output
            self.global_step += 1
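
# All three trainers above rely on metrics.iou_from_poly. As a reference, here
# is a minimal, self-contained sketch of what such a helper typically does:
# rasterize both polygons to binary masks, then take intersection over union.
# The function name, the cv2-based rasterization, and the 255 fill value are
# illustrative assumptions, not the project's actual implementation.
import numpy as np
import cv2


def iou_from_poly_sketch(pred_poly, gt_poly, height, width):
    """Return (iou, masks) for two (N, 2) integer-pixel polygons."""
    masks = np.zeros((2, height, width), dtype=np.uint8)
    cv2.fillPoly(masks[0], [pred_poly.astype(np.int32).reshape(-1, 1, 2)], 255)
    cv2.fillPoly(masks[1], [gt_poly.astype(np.int32).reshape(-1, 1, 2)], 255)
    inter = np.count_nonzero(np.logical_and(masks[0], masks[1]))
    union = np.count_nonzero(np.logical_or(masks[0], masks[1]))
    # Guard against two empty masks.
    iou = inter / union if union > 0 else 0.0
    return iou, masks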