def forward(self, data, target):
        """Forward pass for the binary seen/unseen mask head.

        Splits the (label, embedding) target pair, converts the label map
        into a 0/1 mask marking pixels whose class is *seen* (not in
        ``self.unseen``), and scores it with the model's 'seenmask' head.

        Returns (score, loss, lbl_pred, lbl_true).
        """
        target, target_embed = target
        labels = target.numpy()

        # Binary seen-mask: 1 where the pixel's class is a seen class.
        seen_classes = [c for c in range(self.n_class) if c not in self.unseen]
        mask = np.in1d(labels.ravel(), seen_classes)
        mask = mask.reshape(labels.shape).astype(int)

        target = torch.from_numpy(mask)
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        score = self.model(data, mode='seenmask')
        loss = utils.cross_entropy2d(score, target, size_average=True)

        # Hard predictions: argmax over the channel dimension, moved to CPU.
        lbl_pred = score.data.max(1)[1].cpu().numpy()
        lbl_true = target.data.cpu()

        return score, loss, lbl_pred, lbl_true
示例#2
0
    def forward(self, data, target):
        """Forward pass through the FCN branch.

        Computes the score map, the configured loss ("cos", "mse" or
        "cross_entropy"), and the predicted label map. Raises ValueError
        if the loss is NaN.

        Returns (score, loss, lbl_pred, lbl_true).
        """
        # With pixel embeddings the target arrives as a (label, embed) pair.
        if self.pixel_embeddings:
            target, target_embed = target
            if self.cuda:
                target_embed = target_embed.cuda()
            target_embed = Variable(target_embed)

        if self.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        score = self.model(data, mode='fcn')

        # Dispatch on the configured loss function.
        if self.loss_func == "cross_entropy":
            loss = utils.cross_entropy2d(score, target, size_average=False)
        elif self.loss_func == "cos":
            loss = utils.cosine_loss(score, target, target_embed)
        elif self.loss_func == "mse":
            loss = utils.mse_loss(score, target, target_embed)

        if np.isnan(float(loss.data[0])):
            raise ValueError('loss is nan while training')

        # Inference: embedding-based nearest-class lookup, or plain argmax.
        if self.pixel_embeddings:
            if self.forced_unseen:
                lbl_pred = utils.infer_lbl_forced_unseen(
                    score, target, self.seen_embeddings,
                    self.unseen_embeddings, self.unseen, self.cuda)
            else:
                lbl_pred = utils.infer_lbl(score, self.embeddings, self.cuda)
        else:
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu()

        return score, loss, lbl_pred, lbl_true
示例#3
0
    def validate(self):
        """
        Function to validate a training model on the val split.

        Runs the segmentation model and the generator over the validation
        loader, accumulates the loss and per-sample label maps, saves tiled
        visualizations and generator outputs, appends a CSV log row, and
        checkpoints the model (keeping a copy of the best model by mean IoU).
        """
        
        self.model.eval()
        self.netG.eval()

        val_loss = 0
        num_vis = 8  # cap on how many samples get visualized/saved
        visualizations = []
        generations = []
        label_trues, label_preds = [], []
        
        # Evaluation
        for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.val_loader), total=len(self.val_loader),
            desc='Validation iteration = %d' % self.iteration):

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            
            # Segmentation scores plus intermediate features feeding netG.
            score, fc7, pool4, pool3 = self.model(data)
            outG = self.netG(fc7, pool4, pool3)

            loss = cross_entropy2d(score, target, size_average=self.size_average)
            if np.isnan(float(loss.data[0])):
                raise ValueError('loss is nan while validating')
            val_loss += float(loss.data[0]) / len(data)

            imgs = data.data.cpu()
            # Hard predictions: argmax over the channel dimension.
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = target.data.cpu()
            
            # Visualizing predicted labels
            for img, lt, lp , outG_ in zip(imgs, lbl_true, lbl_pred,outG):
                
                # Generator output: scale to [0, 255], CHW->HWC, channel-reverse.
                outG_ = outG_*255.0
                outG_ = outG_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)
                img = self.val_loader.dataset.untransform(img.numpy())
                lt = lt.numpy()
                label_trues.append(lt)
                label_preds.append(lp)
                if len(visualizations) < num_vis:
                    # NOTE(review): `lt` was already appended to label_trues,
                    # and this mutates that same array in place — so metrics
                    # below see the -1 remapping for the visualized samples.
                    # Confirm this is intended.
                    lt[lt >= CLASS_NUM] = -1# to make fcn.utils.visualize_segmentation work!
                    viz = fcn.utils.visualize_segmentation(
                        lbl_pred=lp, lbl_true=lt, img=img, n_class=self.n_class)
                    visualizations.append(viz)
                    generations.append(outG_)
        
        # Computing the metrics
        metrics = torchfcn.utils.label_accuracy_score(
            label_trues, label_preds, self.n_class)
        val_loss /= len(self.val_loader)

        # Saving the label visualizations and generations
        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'iter%012d_labelmap.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
        out_file = osp.join(out, 'iter%012d_generations.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))

        # Logging
        logger.info("validation mIoU: {}".format(metrics[2]))
        # CSV row layout: epoch, iteration, 5 blank training slots, val loss,
        # metrics, elapsed time.
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
         elapsed_time = \
             datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
             self.timestamp_start
         log = [self.epoch, self.iteration] + [''] * 5 + \
               [val_loss] + list(metrics) + [elapsed_time]
         log = map(str, log)
         f.write(','.join(log) + '\n')

        # Saving the models
        mean_iu = metrics[2]
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save({
         'epoch': self.epoch,
         'iteration': self.iteration,
         'arch': self.model.__class__.__name__,
         'optim_state_dict': self.optim.state_dict(),
         'model_state_dict': self.model.state_dict(),
         'best_mean_iu': self.best_mean_iu,
        }, osp.join(self.out, 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                     osp.join(self.out, 'model_best.pth.tar'))
示例#4
0
    def train_epoch(self):
        """
        Function to train the model for one epoch

        Jointly trains three networks in an ACGAN-style domain-adaptation
        setup on paired source/target batches:
          (1) the discriminator D (domain branch + classifier branch),
          (2) the generator G (reconstructs images from F's features),
          (3) the segmentation network F (adversarial + segmentation loss).
        Periodically dumps source/target reconstructions and runs validation.
        """
        self.model.train()
        self.netG.train()
        self.netD.train()

        for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)), total=min(len(self.target_loader), len(self.train_loader)),
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):

            data_source, labels_source = datas
            data_target, __ = datat
            data_source_forD = torch.zeros((data_source.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))            
            data_target_forD = torch.zeros((data_target.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
            
            # We pass the unnormalized data to the discriminator. So, the GANs produce images without data normalization
            for i in range(data_source.size()[0]):
                data_source_forD[i] = self.train_loader.dataset.transform_forD(data_source[i], self.image_size_forD, resize=False, mean_add=True)
                data_target_forD[i] = self.train_loader.dataset.transform_forD(data_target[i], self.image_size_forD, resize=False, mean_add=True)

            iteration = batch_idx + self.epoch * min(len(self.train_loader), len(self.target_loader))
            self.iteration = iteration

            if self.cuda:
                data_source, labels_source = data_source.cuda(), labels_source.cuda()
                data_target = data_target.cuda()
                data_source_forD = data_source_forD.cuda()
                data_target_forD = data_target_forD.cuda()
            
            data_source, labels_source = Variable(data_source), Variable(labels_source)
            data_target = Variable(data_target)
            data_source_forD = Variable(data_source_forD)
            data_target_forD = Variable(data_target_forD)



            # Source domain 
            score, fc7, pool4, pool3 = self.model(data_source)
            outG_src = self.netG(fc7, pool4, pool3)
            outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)
            outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
            
            # target domain
            tscore, tfc7, tpool4, tpool3= self.model(data_target)
            outG_tgt = self.netG(tfc7, tpool4, tpool3)
            outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
            outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)

            # Creating labels for D. We need two sets of labels since our model is a ACGAN style framework.
            # (1) Labels for the classsifier branch. This will be a downsampled version of original segmentation labels
            # (2) Domain lables for classifying source real, source fake, target real and target fake
            
            # Labels for classifier branch 
            Dout_sz = outD_src_real_s.size()
            label_forD = torch.zeros((outD_tgt_fake_c.size()[0], outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))#[1,40,80]
            for i in range(label_forD.size()[0]):
                label_forD[i] = self.train_loader.dataset.transform_label_forD(labels_source[i], (outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
            if self.cuda:
                label_forD = label_forD.cuda()
            label_forD = Variable(label_forD.long())

            # Domain labels: 0 = source real, 1 = source fake,
            # 2 = target real, 3 = target fake.
            domain_labels_src_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()
            domain_labels_src_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+1
            domain_labels_tgt_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+2
            domain_labels_tgt_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+3

            domain_labels_src_real = Variable(domain_labels_src_real.cuda())
            domain_labels_src_fake = Variable(domain_labels_src_fake.cuda())
            domain_labels_tgt_real = Variable(domain_labels_tgt_real.cuda())
            domain_labels_tgt_fake = Variable(domain_labels_tgt_fake.cuda())

            
            # Updates.
            # There are three sets of updates - (1) Discriminator, (2) Generator and (3) F network
            
            # (1) Discriminator updates
            lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)
            lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)
            lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)#TODO,buggy
            lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)
            lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)           
            
            self.optimD.zero_grad()            
            lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
            lossD /= len(data_source)
            # retain_graph=True: the same graph is reused by the G and F
            # backward passes below.
            lossD.backward(retain_graph=True)
            self.optimD.step()
        
            
            # (2) Generator updates: fool D's domain branch (fakes labeled as
            # real), match D's classifier on source, plus L1 reconstruction.
            self.optimG.zero_grad()            
            lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossG_src_mse = F.l1_loss(outG_src,data_source_forD)
            lossG_tgt_mse = F.l1_loss(outG_tgt,data_target_forD)

            lossG = lossG_src_adv_c + 0.1*(lossG_src_adv_s+ lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
            lossG /= len(data_source)
            lossG.backward(retain_graph=True)
            self.optimG.step()

            # (3) F network updates: segmentation loss plus adversarial terms
            # with the source/target domain labels swapped (domain confusion).
            self.optim.zero_grad()            
            lossC = cross_entropy2d(score, labels_source,size_average=self.size_average)
            lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            
            lossF = lossC + self.adv_weight*(lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight*lossF_src_adv_c
            lossF /= len(data_source)
            lossF.backward()
            self.optim.step()
            
            if np.isnan(float(lossD.data[0])):
                raise ValueError('lossD is nan while training')
            if np.isnan(float(lossG.data[0])):
                raise ValueError('lossG is nan while training')
            if np.isnan(float(lossF.data[0])):
                raise ValueError('lossF is nan while training')
           
            # Computing metrics for logging
            metrics = []
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = labels_source.data.cpu().numpy()
            for lt, lp in zip(lbl_true, lbl_pred):
                acc, acc_cls, mean_iu, fwavacc = \
                    torchfcn.utils.label_accuracy_score(
                        [lt], [lp], n_class=self.n_class)
                metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            # Logging
            if self.iteration%100 == 0:
                logger.info("epoch: {}/{}, iteration:{}, lossF:{}, mIoU :{}".format(self.epoch, self.max_epoch, self.iteration,lossF.data[0], metrics[2]))
            # CSV row: epoch, iteration, lossF, metrics, 5 blank validation
            # slots, elapsed seconds.
            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [lossF.data[0]] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            if self.iteration >= self.max_iter:
                break
            
            # Validating periodically
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                out_recon = osp.join(self.out, 'visualization_viz')
                if not osp.exists(out_recon):
                    os.makedirs(out_recon)
                generations = []

                # Saving generated source and target images
                # NOTE(review): .squeeze() here assumes batch size 1 — confirm.
                source_img = self.val_loader.dataset.untransform(data_source.data.cpu().numpy().squeeze())
                target_img = self.val_loader.dataset.untransform(data_target.data.cpu().numpy().squeeze())
                outG_src_ = (outG_src)*255.0
                outG_tgt_ = (outG_tgt)*255.0
                outG_src_ = outG_src_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)
                outG_tgt_ = outG_tgt_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)

                generations.append(source_img)
                generations.append(outG_src_)
                generations.append(target_img)
                generations.append(outG_tgt_)
                out_file = osp.join(out_recon, 'iter%012d_src_target_recon.png' % self.iteration)
                scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))

                # Validation
                self.validate()
                self.model.train()
                self.netG.train()
示例#5
0
def train(args, out, net_name):
    """Train `net_name` on `args.dataset` for `args.epochs` epochs.

    Each epoch writes the average training loss (and, every 10th epoch,
    validation scores, per-class IoU, and a model checkpoint) to
    ``out/{net_name}_epoch_{epoch}.txt``.
    """
    data_path = get_data_path(args.dataset)
    data_loader = get_loader(args.dataset)
    loader = data_loader(data_path, is_transform=True)
    n_classes = loader.n_classes
    print(n_classes)
    # NOTE(review): these loader kwargs are never passed to the DataLoaders
    # below — confirm whether num_workers/pin_memory were intended.
    kwargs = {'num_workers': 8, 'pin_memory': True}

    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  shuffle=True)

    another_loader = data_loader(data_path, split='val', is_transform=True)

    valloader = data.DataLoader(another_loader,
                                batch_size=args.batch_size,
                                shuffle=True)

    # Class weights for cross_entropy2d from a label histogram.
    # NOTE(review): `hist` is not defined in this function — presumably a
    # module-level array; confirm it exists, otherwise this raises NameError.
    norm_hist = hist / np.max(hist)
    weight = 1 / np.log(norm_hist + 1.02)
    weight[-1] = 0  # zero weight for the last (void) class
    weight = torch.FloatTensor(weight)
    model = Bilinear_Res(n_classes)

    if torch.cuda.is_available():
        model.cuda(0)
        weight = weight.cuda(0)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr_rate,
                                 weight_decay=args.w_decay)
    scheduler = StepLR(optimizer, step_size=100, gamma=args.lr_decay)

    for epoch in tqdm.tqdm(range(args.epochs),
                           desc='Training',
                           ncols=80,
                           leave=False):
        scheduler.step()
        model.train()
        loss_list = []
        # Renamed from `file` (shadowed the builtin); closed in `finally` so
        # an exception mid-epoch no longer leaks the handle.
        log_file = open(out + '/{}_epoch_{}.txt'.format(net_name, epoch), 'w')
        try:
            for i, (images, labels) in tqdm.tqdm(enumerate(trainloader),
                                                 total=len(trainloader),
                                                 desc='Iteration',
                                                 ncols=80,
                                                 leave=False):
                if torch.cuda.is_available():
                    images = Variable(images.cuda(0))
                    labels = Variable(labels.cuda(0))
                else:
                    images = Variable(images)
                    labels = Variable(labels)
                optimizer.zero_grad()
                outputs = model(images)
                loss = cross_entropy2d(outputs, labels, weight=weight)
                loss_list.append(loss.data[0])
                loss.backward()
                optimizer.step()

            # Epoch-average training loss.
            print(np.average(loss_list))
            log_file.write(str(np.average(loss_list)) + '\n')
            model.eval()
            gts, preds = [], []
            if epoch % 10 == 0:
                for i, (images, labels) in tqdm.tqdm(enumerate(valloader),
                                                     total=len(valloader),
                                                     desc='Valid Iteration',
                                                     ncols=80,
                                                     leave=False):
                    if torch.cuda.is_available():
                        images = Variable(images.cuda(0))
                        labels = Variable(labels.cuda(0))
                    else:
                        images = Variable(images)
                        labels = Variable(labels)
                    outputs = model(images)
                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    for gt_, pred_ in zip(gt, pred):
                        gts.append(gt_)
                        preds.append(pred_)
                # Overall scores plus per-class IoU for this validation pass.
                score, class_iou = scores(gts, preds, n_class=n_classes)
                for k, v in score.items():
                    log_file.write('{} {}\n'.format(k, v))

                for i in range(n_classes):
                    log_file.write('{} {}\n'.format(i, class_iou[i]))
                torch.save(
                    model.state_dict(),
                    out + "/{}_{}_{}.pkl".format(net_name, args.dataset, epoch))
        finally:
            log_file.close()
示例#6
0
def main():
  """Train a DeepLab model initialized from COCO-pretrained weights.

  Runs 20 epochs over the VOC train split, validates mid-epoch every 4000
  batches and at each epoch end, keeps the best checkpoint as
  ``model_best.pth`` and a periodic checkpoint every 5 epochs.
  """

  #########  configs ###########
  best_metric = 0

  pretrain_deeplab_path = os.path.join(configs.py_dir, 'model/deeplab_coco.pth')

  ######  load datasets ########
  train_transform_det = trans.Compose([
      trans.Scale((321, 321)),
  ])
  val_transform_det = trans.Compose([
      trans.Scale((321, 321)),
  ])

  train_data = voc_dates.VOCDataset(configs.train_img_dir, configs.train_label_dir,
                                    configs.train_txt_dir, 'train', transform=True,
                                    transform_med=train_transform_det)
  train_loader = Data.DataLoader(train_data, batch_size=configs.batch_size,
                                 shuffle=True, num_workers=4, pin_memory=True)

  val_data = voc_dates.VOCDataset(configs.val_img_dir, configs.val_label_dir,
                                  configs.val_txt_dir, 'val', transform=True,
                                  transform_med=val_transform_det)
  val_loader = Data.DataLoader(val_data, batch_size=configs.batch_size,
                               shuffle=False, num_workers=4, pin_memory=True)

  ######  build  models ########
  deeplab = models.deeplab()
  deeplab_pretrain_model = utils.load_deeplab_pretrain_model(pretrain_deeplab_path)
  deeplab.init_parameters(deeplab_pretrain_model)
  deeplab = deeplab.cuda()

  # NOTE(review): `resume` is not defined in this function — presumably a
  # module-level flag; confirm it exists before running.
  if resume:
      checkpoint = torch.load(configs.best_ckpt_dir)
      deeplab.load_state_dict(checkpoint['state_dict'])
      print('resume success')  # fixed typo: was 'resum sucess'

  ######### optimizer ##########
  # Bias parameters get a doubled learning rate and no weight decay.
  optimizer = torch.optim.SGD(
      [
          {'params': get_parameters(deeplab, bias=False)},
          {'params': get_parameters(deeplab, bias=True),
           'lr': configs.learning_rate * 2, 'weight_decay': 0},
      ], lr=configs.learning_rate, momentum=configs.momentum,
      weight_decay=configs.weight_decay)

  ######## iter img_label pairs ###########
  for epoch in range(20):

      utils.adjust_learning_rate(configs.learning_rate, optimizer, epoch)
      for batch_idx, batch in enumerate(train_loader):

          img_idx, label_idx, filename, height, width = batch
          img, label = Variable(img_idx.cuda()), Variable(label_idx.cuda())
          prediction, weights = deeplab(img)
          loss = utils.cross_entropy2d(prediction, label, size_average=False)
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          if batch_idx % 20 == 0:
              print("Epoch [%d/%d] Loss: %.4f" % (epoch, batch_idx, loss.data[0]))

          if batch_idx % 4000 == 0:
              # Mid-epoch validation.
              current_metric = validate(deeplab, val_loader, epoch)
              print(current_metric)  # was a Python-2 print statement

      current_metric = validate(deeplab, val_loader, epoch)

      # Keep the best-scoring checkpoint as model_best.pth.
      if current_metric > best_metric:
          ckpt_path = os.path.join(configs.save_ckpt_dir,
                                   'deeplab' + str(epoch) + '.pth')
          torch.save({'state_dict': deeplab.state_dict()}, ckpt_path)
          shutil.copy(ckpt_path,
                      os.path.join(configs.save_ckpt_dir, 'model_best.pth'))
          best_metric = current_metric

      # Periodic checkpoint every 5 epochs.
      if epoch % 5 == 0:
          torch.save({'state_dict': deeplab.state_dict()},
                     os.path.join(configs.save_ckpt_dir,
                                  'deeplab' + str(epoch) + '.pth'))
示例#7
0
    def validate(self):
        """
        Function to validate a training model on the val split.

        Accumulates the validation loss and label predictions over the
        whole split, saves a tiled visualization, appends a CSV log row,
        and checkpoints the model (keeping a copy of the best by mean IoU).
        """

        self.model.eval()
        running_loss = 0
        max_vis = 8  # cap on visualized samples
        viz_tiles = []
        gt_labels, pred_labels = [], []

        # Forward every validation batch and collect loss / predictions.
        progress = tqdm.tqdm(enumerate(self.val_loader),
                             total=len(self.val_loader),
                             desc='Valid iteration=%d' % self.iteration,
                             ncols=80,
                             leave=False)
        for batch_idx, (data, target) in progress:
            if self.cuda:
                data, target = data.cuda(), target.cuda()

            data, target = Variable(data, volatile=True), Variable(target)
            score = self.model(data)
            loss = cross_entropy2d(score, target,
                                   size_average=self.size_average)
            loss_value = float(loss.data[0])
            if np.isnan(loss_value):
                raise ValueError('loss is nan while validating')
            running_loss += loss_value / len(data)

            imgs = data.data.cpu()
            # Hard predictions via channel-wise argmax.
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = target.data.cpu()

            # Collect ground truth / predictions and a few visualizations.
            for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                img = self.val_loader.dataset.untransform(img.numpy())
                lt = lt.numpy()
                gt_labels.append(lt)
                pred_labels.append(lp)
                if len(viz_tiles) < max_vis:
                    viz_tiles.append(fcn.utils.visualize_segmentation(
                        lbl_pred=lp,
                        lbl_true=lt,
                        img=img,
                        n_class=self.n_class))

        # Aggregate accuracy / IoU metrics over the whole split.
        metrics = torchfcn.utils.label_accuracy_score(gt_labels, pred_labels,
                                                      self.n_class)

        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'iter%012d.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(viz_tiles))

        running_loss /= len(self.val_loader)

        # CSV row: epoch, iteration, 5 blank training slots, val loss,
        # metrics, elapsed time.
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = \
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
                self.timestamp_start
            row = [self.epoch, self.iteration] + [''] * 5 + \
                  [running_loss] + list(metrics) + [elapsed_time]
            f.write(','.join(map(str, row)) + '\n')

        # Checkpoint; keep a copy of the best model by mean IoU.
        mean_iu = metrics[2]
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        checkpoint_path = osp.join(self.out, 'checkpoint.pth.tar')
        torch.save(
            {
                'epoch': self.epoch,
                'iteration': self.iteration,
                'arch': self.model.__class__.__name__,
                'optim_state_dict': self.optim.state_dict(),
                'model_state_dict': self.model.state_dict(),
                'best_mean_iu': self.best_mean_iu,
            }, checkpoint_path)
        if is_best:
            shutil.copy(checkpoint_path,
                        osp.join(self.out, 'model_best.pth.tar'))
示例#8
0
    def train_epoch(self):
        """
        Function to train the model for one epoch

        For each batch: runs periodic validation, forward-passes the source
        data, computes the batch-normalized segmentation loss, updates the
        model, and appends per-batch metrics to log.csv. Stops early once
        `max_iter` is reached.
        """

        self.model.train()

        # Loop for training the model
        for batch_idx, datas in tqdm.tqdm(enumerate(self.train_loader),
                                          total=len(self.train_loader),
                                          desc='Train epoch=%d' % self.epoch,
                                          ncols=80,
                                          leave=False):

            iteration = batch_idx + self.epoch * len(self.train_loader)
            self.iteration = iteration
            # Periodic validation; validate() puts the model in eval mode,
            # so switch back to train mode afterwards.
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                self.validate()
            self.model.train()

            # Obtaining data in the right format
            data_source, labels_source = datas
            if self.cuda:
                data_source, labels_source = data_source.cuda(
                ), labels_source.cuda()
            data_source, labels_source = Variable(data_source), Variable(
                labels_source)

            # Forward pass (zero_grad was previously called twice; once is
            # sufficient since no backward happens in between).
            self.optim.zero_grad()
            source_pred = self.model(data_source)

            # Segmentation loss, normalized by the batch size.
            loss_seg = cross_entropy2d(source_pred,
                                       labels_source,
                                       size_average=self.size_average)
            loss_seg /= len(data_source)

            # Updating the model (backward pass)
            loss_seg.backward()
            self.optim.step()

            if np.isnan(float(loss_seg.data[0])):
                raise ValueError('loss is nan while training')

            # Computing and logging performance metrics
            metrics = []
            lbl_pred = source_pred.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = labels_source.data.cpu().numpy()
            for lt, lp in zip(lbl_true, lbl_pred):
                acc, acc_cls, mean_iu, fwavacc = \
                    torchfcn.utils.label_accuracy_score(
                        [lt], [lp], n_class=self.n_class)
                metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            # CSV row: epoch, iteration, loss, metrics, 5 blank validation
            # slots, elapsed seconds.
            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [loss_seg.data[0]] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            if self.iteration >= self.max_iter:
                break
0
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        self.model.train()
        self.netG.train()
        self.netD.train()

        for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)), total=min(len(self.target_loader), len(self.train_loader)),
            desc='Train epoch = %d' % self.epoch, ncols=80, leave=False):

            data_source, labels_source = datas
            data_target, __ = datat
            data_source_forD = torch.zeros((data_source.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))            
            data_target_forD = torch.zeros((data_target.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
            
            # We pass the unnormalized data to the discriminator. So, the GANs produce images without data normalization
            for i in range(data_source.size()[0]):
                data_source_forD[i] = self.train_loader.dataset.transform_forD(data_source[i], self.image_size_forD, resize=False, mean_add=True)
                data_target_forD[i] = self.train_loader.dataset.transform_forD(data_target[i], self.image_size_forD, resize=False, mean_add=True)

            iteration = batch_idx + self.epoch * min(len(self.train_loader), len(self.target_loader))
            self.iteration = iteration

            if self.cuda:
                data_source, labels_source = data_source.cuda(), labels_source.cuda()
                data_target = data_target.cuda()
                data_source_forD = data_source_forD.cuda()
                data_target_forD = data_target_forD.cuda()
            
            data_source, labels_source = Variable(data_source), Variable(labels_source)
            data_target = Variable(data_target)
            data_source_forD = Variable(data_source_forD)
            data_target_forD = Variable(data_target_forD)



            # Source domain 
            score, fc7, pool4, pool3 = self.model(data_source)
            outG_src = self.netG(fc7, pool4, pool3)
            outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)
            outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
            
            # target domain
            tscore, tfc7, tpool4, tpool3= self.model(data_target)
            outG_tgt = self.netG(tfc7, tpool4, tpool3)
            outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
            outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)

            # Creating labels for D. We need two sets of labels since our model is a ACGAN style framework.
            # (1) Labels for the classsifier branch. This will be a downsampled version of original segmentation labels
            # (2) Domain lables for classifying source real, source fake, target real and target fake
            
            # Labels for classifier branch 
            Dout_sz = outD_src_real_s.size()
            label_forD = torch.zeros((outD_tgt_fake_c.size()[0], outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
            for i in range(label_forD.size()[0]):
                label_forD[i] = self.train_loader.dataset.transform_label_forD(labels_source[i], (outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
            if self.cuda:
                label_forD = label_forD.cuda()
            label_forD = Variable(label_forD.long())

            # Domain labels
            domain_labels_src_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()
            domain_labels_src_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+1
            domain_labels_tgt_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+2
            domain_labels_tgt_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+3

            domain_labels_src_real = Variable(domain_labels_src_real.cuda())
            domain_labels_src_fake = Variable(domain_labels_src_fake.cuda())
            domain_labels_tgt_real = Variable(domain_labels_tgt_real.cuda())
            domain_labels_tgt_fake = Variable(domain_labels_tgt_fake.cuda())

            
            # Updates.
            # There are three sets of updates - (1) Discriminator, (2) Generator and (3) F network
            
            # (1) Discriminator updates
            lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)
            lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)
            lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)
            lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)
            lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)           
            
            self.optimD.zero_grad()            
            lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
            lossD /= len(data_source)
            lossD.backward(retain_graph=True)
            self.optimD.step()
        
            
            # (2) Generator updates
            self.optimG.zero_grad()            
            lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossG_src_mse = F.l1_loss(outG_src,data_source_forD)
            lossG_tgt_mse = F.l1_loss(outG_tgt,data_target_forD)

            lossG = lossG_src_adv_c + 0.1*(lossG_src_adv_s+ lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
            lossG /= len(data_source)
            lossG.backward(retain_graph=True)
            self.optimG.step()

            # (3) F network updates 
            self.optim.zero_grad()            
            lossC = cross_entropy2d(score, labels_source,size_average=self.size_average)
            lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            
            lossF = lossC + self.adv_weight*(lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight*lossF_src_adv_c
            lossF /= len(data_source)
            lossF.backward()
            self.optim.step()
            
            if np.isnan(float(lossD.data[0])):
                raise ValueError('lossD is nan while training')
            if np.isnan(float(lossG.data[0])):
                raise ValueError('lossG is nan while training')
            if np.isnan(float(lossF.data[0])):
                raise ValueError('lossF is nan while training')
           
            # Computing metrics for logging
            metrics = []
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = labels_source.data.cpu().numpy()
            for lt, lp in zip(lbl_true, lbl_pred):
                acc, acc_cls, mean_iu, fwavacc = \
                    torchfcn.utils.label_accuracy_score(
                        [lt], [lp], n_class=self.n_class)
                metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            # Logging
            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [lossF.data[0]] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            if self.iteration >= self.max_iter:
                break
            
            # Validating periodically
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                out_recon = osp.join(self.out, 'visualization_viz')
                if not osp.exists(out_recon):
                    os.makedirs(out_recon)
                generations = []

                # Saving generated source and target images
                source_img = self.val_loader.dataset.untransform(data_source.data.cpu().numpy().squeeze())
                target_img = self.val_loader.dataset.untransform(data_target.data.cpu().numpy().squeeze())
                outG_src_ = (outG_src)*255.0
                outG_tgt_ = (outG_tgt)*255.0
                outG_src_ = outG_src_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)
                outG_tgt_ = outG_tgt_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)

                generations.append(source_img)
                generations.append(outG_src_)
                generations.append(target_img)
                generations.append(outG_tgt_)
                out_file = osp.join(out_recon, 'iter%012d_src_target_recon.png' % self.iteration)
                scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))

                # Validation
                self.validate()
                self.model.train()
                self.netG.train()
示例#10
0
    def validate(self):
        """
        Run the model on the whole validation split.

        Computes the average cross-entropy loss and segmentation metrics,
        saves label-map and generator-output visualizations to disk, appends
        a row to log.csv, and checkpoints the model (keeping a copy of the
        best checkpoint by mean IoU).

        Side effects: writes image files and log.csv under ``self.out`` and
        saves ``checkpoint.pth.tar`` / ``model_best.pth.tar``.

        Raises:
            ValueError: if the validation loss becomes NaN.
        """
        
        # Switch both networks to eval mode (affects dropout / batch norm).
        self.model.eval()
        self.netG.eval()

        val_loss = 0
        num_vis = 8  # cap on the number of qualitative visualizations kept
        visualizations = []
        generations = []
        label_trues, label_preds = [], []
        
        # Evaluation
        for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.val_loader), total=len(self.val_loader),
            desc='Validation iteration = %d' % self.iteration, ncols=80,
            leave=False):

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            # volatile=True is the legacy (pre-0.4 PyTorch) way to disable autograd.
            data, target = Variable(data, volatile=True), Variable(target)
            
            # F network yields the segmentation score plus intermediate
            # features that the generator G consumes.
            score, fc7, pool4, pool3 = self.model(data)
            outG = self.netG(fc7, pool4, pool3)

            loss = cross_entropy2d(score, target, size_average=self.size_average)
            if np.isnan(float(loss.data[0])):
                raise ValueError('loss is nan while validating')
            val_loss += float(loss.data[0]) / len(data)

            imgs = data.data.cpu()
            # argmax over the class channel -> per-pixel predicted labels
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = target.data.cpu()
            
            # Visualizing predicted labels
            for img, lt, lp , outG_ in zip(imgs, lbl_true, lbl_pred,outG):
                
                # Scale generator output to 8-bit and flip channel order
                # (assumes outG is in [0, 1], CHW, BGR -- TODO confirm).
                outG_ = outG_*255.0
                outG_ = outG_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)
                img = self.val_loader.dataset.untransform(img.numpy())
                lt = lt.numpy()
                label_trues.append(lt)
                label_preds.append(lp)
                if len(visualizations) < num_vis:
                    viz = fcn.utils.visualize_segmentation(
                        lbl_pred=lp, lbl_true=lt, img=img, n_class=self.n_class)
                    visualizations.append(viz)
                    generations.append(outG_)
        
        # Computing the metrics
        metrics = torchfcn.utils.label_accuracy_score(
            label_trues, label_preds, self.n_class)
        val_loss /= len(self.val_loader)

        # Saving the label visualizations and generations
        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'iter%012d_labelmap.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
        out_file = osp.join(out, 'iter%012d_generations.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))

        # Logging: validation rows leave the 5 training-metric columns blank.
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
         elapsed_time = \
             datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
             self.timestamp_start
         log = [self.epoch, self.iteration] + [''] * 5 + \
               [val_loss] + list(metrics) + [elapsed_time]
         log = map(str, log)
         f.write(','.join(log) + '\n')

        # Saving the models
        # metrics is (acc, acc_cls, mean_iu, fwavacc); index 2 is mean IoU.
        mean_iu = metrics[2]
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save({
         'epoch': self.epoch,
         'iteration': self.iteration,
         'arch': self.model.__class__.__name__,
         'optim_state_dict': self.optim.state_dict(),
         'model_state_dict': self.model.state_dict(),
         'best_mean_iu': self.best_mean_iu,
        }, osp.join(self.out, 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                     osp.join(self.out, 'model_best.pth.tar'))
示例#11
0
    def train_epoch(self):
        """
        Train the segmentation model for one epoch on the source dataset.

        For every batch: run validation when the validation interval is hit,
        forward the batch, compute the 2-D cross-entropy segmentation loss,
        update the model, and append metrics to log.csv. Stops early once
        ``self.max_iter`` is reached.

        Raises:
            ValueError: if the segmentation loss becomes NaN.
        """
        
        self.model.train()

        # Loop for training the model
        for batch_idx, datas in tqdm.tqdm(
                enumerate(self.train_loader), total= len(self.train_loader),
                desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
            
            # Global iteration count across epochs.
            iteration = batch_idx + self.epoch * len(self.train_loader)
            self.iteration = iteration
            if self.iteration % self.interval_validate == 0 and self.iteration>0:
                self.validate()
            # validate() switches the model to eval mode; switch back.
            self.model.train()

            # Obtaining data in the right format
            data_source, labels_source = datas
            if self.cuda:
                data_source, labels_source = data_source.cuda(), labels_source.cuda()
            data_source, labels_source = Variable(data_source), Variable(labels_source)
            
            # Forward pass
            source_pred = self.model(data_source)
            
            # Computing the segmentation loss, normalized by batch size
            loss_seg = cross_entropy2d(source_pred, labels_source, size_average=self.size_average)
            loss_seg /= len(data_source)
            
            # Updating the model (backward pass).
            # Note: the original code called optim.zero_grad() twice (before
            # the forward pass and again here); a single call suffices.
            self.optim.zero_grad()
            loss_seg.backward()
            self.optim.step()
            
            if np.isnan(float(loss_seg.data[0])):
                raise ValueError('loss is nan while training')

            # Computing and logging performance metrics
            metrics = []
            lbl_pred = source_pred.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = labels_source.data.cpu().numpy()
            for lt, lp in zip(lbl_true, lbl_pred):
                acc, acc_cls, mean_iu, fwavacc = \
                    torchfcn.utils.label_accuracy_score(
                        [lt], [lp], n_class=self.n_class)
                metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [loss_seg.data[0]] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            if self.iteration >= self.max_iter:
                break
    def test(self):
        """
        Evaluate the model on the test split.

        Accumulates per-batch cross-entropy loss and segmentation metrics,
        saves a tile of qualitative visualizations, appends a row to log.csv,
        and pushes scalar summaries to the tensorboard logger. The model's
        previous training/eval mode is restored before returning.

        Raises:
            ValueError: if the loss becomes NaN while testing.
        """
        # Remember the current mode so it can be restored afterwards.
        training = self.model.training
        self.model.eval()

        n_class = len(self.test_loader.dataset.class_names)

        test_loss = 0
        visualizations = []
        label_trues, label_preds = [], []
        for batch_idx, (data, target) in tqdm.tqdm(enumerate(self.test_loader),
                                                   total=len(self.test_loader),
                                                   desc='Test iteration=%d' %
                                                   self.iteration,
                                                   ncols=80,
                                                   leave=False):
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            # volatile=True disables autograd (legacy pre-0.4 PyTorch API).
            data, target = Variable(data, volatile=True), Variable(target)
            score = self.model(data)

            loss = utils.cross_entropy2d(score,
                                         target,
                                         size_average=self.size_average)
            loss_data = float(loss.data[0])
            if np.isnan(loss_data):
                raise ValueError('loss is nan while testing')
            test_loss += loss_data / len(data)

            imgs = data.data.cpu()
            # argmax over the class channel -> per-pixel predicted labels
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :].astype(
                np.uint8)
            lbl_true = target.data.cpu().numpy()
            for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                img, lt = self.test_loader.dataset.untransform(img, lt)
                label_trues.append(lt)
                label_preds.append(lp)
                # Keep at most 9 qualitative examples for the tiled image.
                if len(visualizations) < 9:
                    viz = fcn.utils.visualize_segmentation(lbl_pred=lp,
                                                           lbl_true=lt,
                                                           img=img,
                                                           n_class=n_class)
                    visualizations.append(viz)
        metrics = utils.label_accuracy_score(label_trues, label_preds, n_class)

        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'iter_test_%012d.jpg' % self.iteration)
        scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))

        test_loss /= len(self.test_loader)

        # log.csv: test rows leave the 5 training-metric columns blank.
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration
                   ] + [''] * 5 + [test_loss] + list(metrics) + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')

        # logging information for tensorboard
        info = OrderedDict({
            "loss": test_loss,
            "acc": metrics[0],
            "acc_cls": metrics[1],
            "meanIoU": metrics[2],
            "fwavacc": metrics[3],
            "bestIoU": self.best_mean_iu,
        })
        # (Removed a stray no-op `len(self.train_loader)` statement here.)
        # Fractional epoch position, used as the tensorboard x-axis.
        partial_epoch = self.iteration / len(self.train_loader)
        for tag, value in info.items():
            self.ts_logger.scalar_summary(tag, value, partial_epoch)

        if training:
            self.model.train()
    def train_epoch(self):
        """
        Train the model for one epoch, validating and testing periodically.

        Skips already-processed iterations when resuming from a checkpoint,
        runs validate()/test() every ``self.interval_validate`` iterations,
        performs a class-weighted 2-D cross-entropy update, and logs metrics
        both to log.csv and to the tensorboard logger. Breaks out of the loop
        once ``self.max_iter`` is reached.

        Raises:
            ValueError: if the training loss becomes NaN.
        """
        self.model.train()

        n_class = len(self.train_loader.dataset.class_names)

        for batch_idx, (data, target) in tqdm.tqdm(
                enumerate(self.train_loader),
                total=len(self.train_loader),
                desc='Train epoch=%d' % self.epoch,
                ncols=80,
                leave=False):
            # Global iteration count across epochs.
            iteration = batch_idx + self.epoch * len(self.train_loader)
            # Fast-forward past batches already processed before a resume.
            if self.iteration != 0 and (iteration - 1) != self.iteration:
                continue  # for resuming
            self.iteration = iteration

            if self.iteration % self.interval_validate == 0:
                self.validate()
                self.test()

            # validate()/test() are expected to restore training mode;
            # this assert guards that assumption.
            assert self.model.training

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            self.optim.zero_grad()
            score = self.model(data)
            # Per-class weights and the ignore label come from the dataset.
            weights = torch.from_numpy(
                self.train_loader.dataset.class_weights).float().cuda()
            ignore = self.train_loader.dataset.class_ignore
            loss = utils.cross_entropy2d(score,
                                         target,
                                         weight=weights,
                                         size_average=self.size_average,
                                         ignore=ignore)
            # Normalize by batch size before backprop.
            loss /= len(data)
            loss_data = float(loss.data[0])
            if np.isnan(loss_data):
                raise ValueError('loss is nan while training')
            loss.backward()
            self.optim.step()

            metrics = []
            # argmax over the class channel -> per-pixel predicted labels
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = target.data.cpu().numpy()
            acc, acc_cls, mean_iu, fwavacc = utils.label_accuracy_score(
                lbl_true, lbl_pred, n_class=n_class)
            metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [loss_data] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            # logging to tensorboard
            self.best_train_meanIoU = max(self.best_train_meanIoU, metrics[2])
            info = OrderedDict({
                "loss": loss.data[0],
                "acc": metrics[0],
                "acc_cls": metrics[1],
                "meanIoU": metrics[2],
                "fwavacc": metrics[3],
                "bestIoU": self.best_train_meanIoU,
            })
            # Fractional epoch position, used as the tensorboard x-axis.
            partialEpoch = self.epoch + float(batch_idx) / len(
                self.train_loader)
            for tag, value in info.items():
                self.t_logger.scalar_summary(tag, value, partialEpoch)

            if self.iteration >= self.max_iter:
                break
def main():
    """
    Train an FCN32s model on the VOC dataset and keep the best checkpoint.

    Builds train/val data loaders, initializes FCN32s from a pretrained
    VGG16, optionally resumes from a checkpoint, and trains for 20 epochs
    with SGD (bias parameters get 2x learning rate and zero weight decay).
    After each epoch the model is validated; the best-scoring weights are
    copied to model_best.pth, and a snapshot is saved every 5 epochs.
    """

    #########  configs ###########
    best_metric = 0
    pretrain_vgg16_path = os.path.join(configs.py_dir,
                                       'model/vgg16_from_caffe.pth')

    ######  load datasets ########

    train_data = voc_dates.VOCDataset(configs.train_img_dir,
                                      configs.train_label_dir,
                                      configs.train_txt_dir,
                                      'train',
                                      transform=True)
    train_loader = Data.DataLoader(train_data,
                                   batch_size=configs.batch_size,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=True)

    val_data = voc_dates.VOCDataset(configs.val_img_dir,
                                    configs.val_label_dir,
                                    configs.val_txt_dir,
                                    'val',
                                    transform=True)
    val_loader = Data.DataLoader(val_data,
                                 batch_size=configs.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
    ######  build  models ########
    # Initialize FCN32s weights from a Caffe-converted VGG16 checkpoint.
    fcn32s = models.fcn32s()
    vgg_pretrain_model = utils.load_pretrain_model(pretrain_vgg16_path)
    fcn32s.init_parameters(vgg_pretrain_model)
    fcn32s = fcn32s.cuda()
    #########

    # NOTE(review): `resume` is not defined in this function -- presumably a
    # module-level flag; confirm it exists before running.
    if resume:
        checkpoint = torch.load(configs.best_ckpt_dir)
        fcn32s.load_state_dict(checkpoint['state_dict'])
        print('resum sucess')

    ######### optimizer ##########
    ######## how to set different learning rate for differern layer #########
    # Biases train with 2x learning rate and no weight decay.
    optimizer = torch.optim.SGD([
        {
            'params': get_parameters(fcn32s, bias=False)
        },
        {
            'params': get_parameters(fcn32s, bias=True),
            'lr': configs.learning_rate * 2,
            'weight_decay': 0
        },
    ],
                                lr=configs.learning_rate,
                                momentum=configs.momentum,
                                weight_decay=configs.weight_decay)

    ######## iter img_label pairs ###########

    for epoch in range(20):

        utils.adjust_learning_rate(configs.learning_rate, optimizer, epoch)
        for batch_idx, (img_idx, label_idx) in enumerate(train_loader):

            img, label = Variable(img_idx.cuda()), Variable(label_idx.cuda())
            prediction = fcn32s(img)
            loss = utils.cross_entropy2d(prediction, label, size_average=False)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the running loss every 20 batches.
            if (batch_idx) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch, batch_idx, loss.data[0]))

        current_metric = validate(fcn32s, val_loader, epoch)

        # Keep a copy of the best-performing weights so far.
        if current_metric > best_metric:

            torch.save({'state_dict': fcn32s.state_dict()},
                       os.path.join(configs.save_ckpt_dir,
                                    'fcn32s' + str(epoch) + '.pth'))

            shutil.copy(
                os.path.join(configs.save_ckpt_dir,
                             'fcn32s' + str(epoch) + '.pth'),
                os.path.join(configs.save_ckpt_dir, 'model_best.pth'))
            best_metric = current_metric

        # Periodic snapshot regardless of validation score.
        if epoch % 5 == 0:
            torch.save({'state_dict': fcn32s.state_dict()},
                       os.path.join(configs.save_ckpt_dir,
                                    'fcn32s' + str(epoch) + '.pth'))
示例#15
0
#                 ], lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Plain SGD over all model parameters.
# NOTE(review): `alearning_rate` looks like a typo for `learning_rate`
# (cf. the commented-out call above) -- confirm it is actually defined.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=alearning_rate,
                            momentum=momentum,
                            weight_decay=weight_decay)

# Step the learning rate down by 10x at 40% / 70% / 80% / 90% of training.
scheduler = lr_scheduler.MultiStepLR(optimizer,
                                     milestones=[
                                         int(0.4 * end_epoch),
                                         int(0.7 * end_epoch),
                                         int(0.8 * end_epoch),
                                         int(0.9 * end_epoch)
                                     ],
                                     gamma=0.1)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',patience=10, verbose=True)
# NOTE(review): cross_entropy2d is invoked with no arguments here; verify it
# returns a callable criterion object rather than computing a loss directly.
criterion = cross_entropy2d()

# resume
# Restore model/optimizer state from an existing checkpoint if requested.
if (os.path.isfile(resume_path) and resume_flag):
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_iou = checkpoint['best_iou']
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))