def forward(self, data, target):
    """Forward pass that predicts a binary seen/unseen mask.

    `target` arrives as a (labels, embeddings) pair; the embedding half is
    ignored here. Labels are collapsed to 1 for "seen" classes and 0 for
    "unseen" ones before computing the 2-D cross-entropy loss.

    Returns:
        (score, loss, lbl_pred, lbl_true)
    """
    target, _unused_embed = target  # embedding part not needed for the mask
    lbl_np = target.numpy()
    # Collapse the per-class labels into a binary "is this class seen?" mask.
    seen_classes = [c for c in range(self.n_class) if c not in self.unseen]
    mask = np.in1d(lbl_np.ravel(), seen_classes).reshape(lbl_np.shape).astype(int)
    target = torch.from_numpy(mask)
    if self.cuda:
        data = data.cuda()
        target = target.cuda()
    data = Variable(data)
    target = Variable(target)
    score = self.model(data, mode='seenmask')
    loss = utils.cross_entropy2d(score, target, size_average=True)
    # Hard predictions: argmax over the class channel.
    lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
    lbl_true = target.data.cpu()
    return score, loss, lbl_pred, lbl_true
def forward(self, data, target):
    """Forward pass through the FCN branch: score, loss and label prediction.

    Args:
        data: input image batch.
        target: label batch, or a (labels, pixel_embeddings) tuple when
            ``self.pixel_embeddings`` is set.

    Returns:
        (score, loss, lbl_pred, lbl_true)

    Raises:
        ValueError: if the loss is NaN, or if ``self.loss_func`` is not one
            of "cos"/"mse"/"cross_entropy". (Previously an unknown loss name
            fell through every branch and crashed later with an
            UnboundLocalError on ``loss``.)
    """
    # get score
    if self.pixel_embeddings:
        target, target_embed = target
    if self.cuda:
        data, target = data.cuda(), target.cuda()
    data, target = Variable(data), Variable(target)
    if self.pixel_embeddings:
        if self.cuda:
            target_embed = target_embed.cuda()
        target_embed = Variable(target_embed)
    score = self.model(data, mode='fcn')
    # get loss
    if self.loss_func == "cos":
        loss = utils.cosine_loss(score, target, target_embed)
    elif self.loss_func == "mse":
        loss = utils.mse_loss(score, target, target_embed)
    elif self.loss_func == "cross_entropy":
        loss = utils.cross_entropy2d(score, target, size_average=False)
    else:
        # Fix: fail with a clear message instead of an UnboundLocalError below.
        raise ValueError('unknown loss_func: {}'.format(self.loss_func))
    if np.isnan(float(loss.data[0])):
        raise ValueError('loss is nan while training')
    # inference
    if self.pixel_embeddings:
        if self.forced_unseen:
            # Embedding-space inference that forces unseen-class predictions.
            lbl_pred = utils.infer_lbl_forced_unseen(
                score, target, self.seen_embeddings, self.unseen_embeddings,
                self.unseen, self.cuda)
        else:
            lbl_pred = utils.infer_lbl(score, self.embeddings, self.cuda)
    else:
        # Plain argmax over class channel.
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
    lbl_true = target.data.cpu()
    return score, loss, lbl_pred, lbl_true
def validate(self):
    """
    Function to validate a training model on the val split.

    Runs the segmenter and the generator over the val loader, accumulates
    the per-batch cross-entropy loss, saves tiled label-map and generation
    images, appends a CSV log row, and checkpoints the model (copying to
    model_best when mean IoU improves).

    NOTE(review): legacy PyTorch (<0.4) APIs (Variable, volatile=True,
    loss.data[0]) are used throughout; kept as-is.
    """
    self.model.eval()
    self.netG.eval()
    val_loss = 0
    num_vis = 8  # cap on how many label-map visualizations are tiled/saved
    visualizations = []
    generations = []
    label_trues, label_preds = [], []
    # Evaluation
    for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.val_loader), total=len(self.val_loader),
            desc='Validation iteration = %d' % self.iteration):
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables autograd history (pre-0.4 inference mode).
        data, target = Variable(data, volatile=True), Variable(target)
        # Segmentation score plus intermediate features consumed by netG.
        score, fc7, pool4, pool3 = self.model(data)
        outG = self.netG(fc7, pool4, pool3)
        loss = cross_entropy2d(score, target, size_average=self.size_average)
        if np.isnan(float(loss.data[0])):
            raise ValueError('loss is nan while validating')
        val_loss += float(loss.data[0]) / len(data)
        imgs = data.data.cpu()
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu()
        # Visualizing predicted labels
        for img, lt, lp, outG_ in zip(imgs, lbl_true, lbl_pred, outG):
            # Generator output scaled to uint8; channel transpose + ::-1
            # converts CHW to HWC and swaps the color channel order.
            outG_ = outG_ * 255.0
            outG_ = outG_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            img = self.val_loader.dataset.untransform(img.numpy())
            lt = lt.numpy()
            label_trues.append(lt)
            label_preds.append(lp)
            if len(visualizations) < num_vis:
                # to make fcn.utils.visualize_segmentation work!
                lt[lt >= CLASS_NUM] = -1
                viz = fcn.utils.visualize_segmentation(
                    lbl_pred=lp, lbl_true=lt, img=img, n_class=self.n_class)
                visualizations.append(viz)
                generations.append(outG_)
    # Computing the metrics
    metrics = torchfcn.utils.label_accuracy_score(
        label_trues, label_preds, self.n_class)
    val_loss /= len(self.val_loader)
    # Saving the label visualizations and generations
    out = osp.join(self.out, 'visualization_viz')
    if not osp.exists(out):
        os.makedirs(out)
    out_file = osp.join(out, 'iter%012d_labelmap.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
    out_file = osp.join(out, 'iter%012d_generations.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))
    # Logging
    logger.info("validation mIoU: {}".format(metrics[2]))
    with open(osp.join(self.out, 'log.csv'), 'a') as f:
        elapsed_time = \
            datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
            self.timestamp_start
        # Five empty columns keep the CSV aligned with the train-side rows.
        log = [self.epoch, self.iteration] + [''] * 5 + \
            [val_loss] + list(metrics) + [elapsed_time]
        log = map(str, log)
        f.write(','.join(log) + '\n')
    # Saving the models: checkpoint always, copy to model_best on improvement.
    mean_iu = metrics[2]
    is_best = mean_iu > self.best_mean_iu
    if is_best:
        self.best_mean_iu = mean_iu
    torch.save({
        'epoch': self.epoch,
        'iteration': self.iteration,
        'arch': self.model.__class__.__name__,
        'optim_state_dict': self.optim.state_dict(),
        'model_state_dict': self.model.state_dict(),
        'best_mean_iu': self.best_mean_iu,
    }, osp.join(self.out, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                    osp.join(self.out, 'model_best.pth.tar'))
def train_epoch(self):
    """
    Function to train the model for one epoch.

    Trains three networks jointly (ACGAN-style domain adaptation):
    F (self.model, segmenter), G (self.netG, image generator) and
    D (self.netD, discriminator with a domain branch *_s and a
    classifier branch *_c). Source/target batches are consumed in
    lockstep; the epoch length is the shorter of the two loaders.

    NOTE(review): Python 2 (`itertools.izip`) and legacy PyTorch (<0.4)
    APIs are used throughout; kept as-is.
    """
    self.model.train()
    self.netG.train()
    self.netD.train()
    for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)),
            total=min(len(self.target_loader), len(self.train_loader)),
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
        data_source, labels_source = datas
        data_target, __ = datat  # target-domain labels are unused
        data_source_forD = torch.zeros((data_source.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
        data_target_forD = torch.zeros((data_target.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
        # We pass the unnormalized data to the discriminator. So, the GANs
        # produce images without data normalization
        for i in range(data_source.size()[0]):
            data_source_forD[i] = self.train_loader.dataset.transform_forD(data_source[i], self.image_size_forD, resize=False, mean_add=True)
            data_target_forD[i] = self.train_loader.dataset.transform_forD(data_target[i], self.image_size_forD, resize=False, mean_add=True)
        iteration = batch_idx + self.epoch * min(len(self.train_loader), len(self.target_loader))
        self.iteration = iteration
        if self.cuda:
            data_source, labels_source = data_source.cuda(), labels_source.cuda()
            data_target = data_target.cuda()
            data_source_forD = data_source_forD.cuda()
            data_target_forD = data_target_forD.cuda()
        data_source, labels_source = Variable(data_source), Variable(labels_source)
        data_target = Variable(data_target)
        data_source_forD = Variable(data_source_forD)
        data_target_forD = Variable(data_target_forD)
        # Source domain
        score, fc7, pool4, pool3 = self.model(data_source)
        outG_src = self.netG(fc7, pool4, pool3)
        outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)
        outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
        # target domain
        tscore, tfc7, tpool4, tpool3 = self.model(data_target)
        outG_tgt = self.netG(tfc7, tpool4, tpool3)
        outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
        outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)
        # Creating labels for D. We need two sets of labels since our model
        # is a ACGAN style framework.
        # (1) Labels for the classifier branch. This will be a downsampled
        #     version of original segmentation labels
        # (2) Domain labels for classifying source real, source fake,
        #     target real and target fake
        # Labels for classifier branch
        Dout_sz = outD_src_real_s.size()
        label_forD = torch.zeros((outD_tgt_fake_c.size()[0], outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))  # [1,40,80]
        for i in range(label_forD.size()[0]):
            label_forD[i] = self.train_loader.dataset.transform_label_forD(labels_source[i], (outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
        if self.cuda:
            label_forD = label_forD.cuda()
        label_forD = Variable(label_forD.long())
        # Domain labels: 0 = src real, 1 = src fake, 2 = tgt real, 3 = tgt fake
        domain_labels_src_real = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_()
        domain_labels_src_fake = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 1
        domain_labels_tgt_real = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 2
        domain_labels_tgt_fake = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 3
        domain_labels_src_real = Variable(domain_labels_src_real.cuda())
        domain_labels_src_fake = Variable(domain_labels_src_fake.cuda())
        domain_labels_tgt_real = Variable(domain_labels_tgt_real.cuda())
        domain_labels_tgt_fake = Variable(domain_labels_tgt_fake.cuda())
        # Updates.
        # There are three sets of updates - (1) Discriminator, (2) Generator
        # and (3) F network
        # (1) Discriminator updates: real/fake domain classification plus the
        # classifier branch on real source images.
        lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)
        lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)
        lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)  # TODO, buggy
        lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)
        lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)
        self.optimD.zero_grad()
        lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
        lossD /= len(data_source)
        lossD.backward(retain_graph=True)  # graph reused by the G and F updates
        self.optimD.step()
        # (2) Generator updates: fool D's domain branch (fakes scored with
        # "real" labels), satisfy the classifier branch, plus L1 reconstruction.
        self.optimG.zero_grad()
        lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real, size_average=self.size_average)
        lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD, size_average=self.size_average)
        lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real, size_average=self.size_average)
        lossG_src_mse = F.l1_loss(outG_src, data_source_forD)  # named *_mse but computed as L1
        lossG_tgt_mse = F.l1_loss(outG_tgt, data_target_forD)
        lossG = lossG_src_adv_c + 0.1 * (lossG_src_adv_s + lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
        lossG /= len(data_source)
        lossG.backward(retain_graph=True)
        self.optimG.step()
        # (3) F network updates: supervised segmentation loss plus adversarial
        # terms with the domain labels swapped (src fakes scored as tgt-real
        # and vice versa).
        self.optim.zero_grad()
        lossC = cross_entropy2d(score, labels_source, size_average=self.size_average)
        lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real, size_average=self.size_average)
        lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real, size_average=self.size_average)
        lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD, size_average=self.size_average)
        lossF = lossC + self.adv_weight * (lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight * lossF_src_adv_c
        lossF /= len(data_source)
        lossF.backward()
        self.optim.step()
        # NaN guards (run after the optimizer steps).
        if np.isnan(float(lossD.data[0])):
            raise ValueError('lossD is nan while training')
        if np.isnan(float(lossG.data[0])):
            raise ValueError('lossG is nan while training')
        if np.isnan(float(lossF.data[0])):
            raise ValueError('lossF is nan while training')
        # Computing metrics for logging
        metrics = []
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = labels_source.data.cpu().numpy()
        for lt, lp in zip(lbl_true, lbl_pred):
            acc, acc_cls, mean_iu, fwavacc = \
                torchfcn.utils.label_accuracy_score(
                    [lt], [lp], n_class=self.n_class)
            metrics.append((acc, acc_cls, mean_iu, fwavacc))
        metrics = np.mean(metrics, axis=0)
        # Logging
        if self.iteration % 100 == 0:
            logger.info("epoch: {}/{}, iteration:{}, lossF:{}, mIoU :{}".format(self.epoch, self.max_epoch, self.iteration, lossF.data[0], metrics[2]))
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration] + [lossF.data[0]] + \
                metrics.tolist() + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')
        if self.iteration >= self.max_iter:
            break
        # Validating periodically
        if self.iteration % self.interval_validate == 0 and self.iteration > 0:
            out_recon = osp.join(self.out, 'visualization_viz')
            if not osp.exists(out_recon):
                os.makedirs(out_recon)
            generations = []
            # Saving generated source and target images
            source_img = self.val_loader.dataset.untransform(data_source.data.cpu().numpy().squeeze())
            target_img = self.val_loader.dataset.untransform(data_target.data.cpu().numpy().squeeze())
            outG_src_ = (outG_src) * 255.0
            outG_tgt_ = (outG_tgt) * 255.0
            # CHW -> HWC, swap channel order, scale to uint8 for imsave.
            outG_src_ = outG_src_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            outG_tgt_ = outG_tgt_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            generations.append(source_img)
            generations.append(outG_src_)
            generations.append(target_img)
            generations.append(outG_tgt_)
            out_file = osp.join(out_recon, 'iter%012d_src_target_recon.png' % self.iteration)
            scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))
            # Validation
            self.validate()
            self.model.train()
            self.netG.train()
def train(args, out, net_name):
    """Train a Bilinear_Res segmentation model with periodic evaluation.

    Args:
        args: namespace providing dataset, batch_size, lr_rate, w_decay,
            lr_decay and epochs.
        out: output directory for per-epoch log files and checkpoints.
        net_name: identifier used in log/checkpoint file names.

    Fixes vs. original: the DataLoader kwargs (num_workers/pin_memory) were
    built but never passed; the log file shadowed the `file` builtin and was
    not closed on exception (now a context manager); removed dead
    commented-out optimizer code.
    """
    data_path = get_data_path(args.dataset)
    data_loader = get_loader(args.dataset)
    loader = data_loader(data_path, is_transform=True)
    n_classes = loader.n_classes
    print(n_classes)
    # Fix: actually forward these options to the loaders (they were unused).
    kwargs = {'num_workers': 8, 'pin_memory': True}
    trainloader = data.DataLoader(loader, batch_size=args.batch_size,
                                  shuffle=True, **kwargs)
    another_loader = data_loader(data_path, split='val', is_transform=True)
    valloader = data.DataLoader(another_loader, batch_size=args.batch_size,
                                shuffle=True, **kwargs)
    # Class-balancing weights for cross_entropy2d from the class-frequency
    # histogram. NOTE(review): `hist` is not defined in this function --
    # presumably a module-level array; confirm.
    norm_hist = hist / np.max(hist)
    weight = 1 / np.log(norm_hist + 1.02)
    weight[-1] = 0  # zero weight for the last class -- presumably void/ignore; confirm
    weight = torch.FloatTensor(weight)
    model = Bilinear_Res(n_classes)
    if torch.cuda.is_available():
        model.cuda(0)
        weight = weight.cuda(0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate,
                                 weight_decay=args.w_decay)
    scheduler = StepLR(optimizer, step_size=100, gamma=args.lr_decay)
    for epoch in tqdm.tqdm(range(args.epochs), desc='Training', ncols=80,
                           leave=False):
        scheduler.step()
        model.train()
        loss_list = []
        # Context manager guarantees the log file is closed even on error;
        # renamed from `file`, which shadowed the builtin.
        with open(out + '/{}_epoch_{}.txt'.format(net_name, epoch), 'w') as log_file:
            for i, (images, labels) in tqdm.tqdm(
                    enumerate(trainloader), total=len(trainloader),
                    desc='Iteration', ncols=80, leave=False):
                if torch.cuda.is_available():
                    images = Variable(images.cuda(0))
                    labels = Variable(labels.cuda(0))
                else:
                    images = Variable(images)
                    labels = Variable(labels)
                optimizer.zero_grad()
                outputs = model(images)
                loss = cross_entropy2d(outputs, labels, weight=weight)
                loss_list.append(loss.data[0])
                loss.backward()
                optimizer.step()
            print(np.average(loss_list))
            log_file.write(str(np.average(loss_list)) + '\n')
            model.eval()
            gts, preds = [], []
            # Full validation pass every 10 epochs.
            if epoch % 10 == 0:
                for i, (images, labels) in tqdm.tqdm(
                        enumerate(valloader), total=len(valloader),
                        desc='Valid Iteration', ncols=80, leave=False):
                    if torch.cuda.is_available():
                        images = Variable(images.cuda(0))
                        labels = Variable(labels.cuda(0))
                    else:
                        images = Variable(images)
                        labels = Variable(labels)
                    outputs = model(images)
                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    for gt_, pred_ in zip(gt, pred):
                        gts.append(gt_)
                        preds.append(pred_)
                score, class_iou = scores(gts, preds, n_class=n_classes)
                for k, v in score.items():
                    log_file.write('{} {}\n'.format(k, v))
                for i in range(n_classes):
                    log_file.write('{} {}\n'.format(i, class_iou[i]))
                torch.save(
                    model.state_dict(),
                    out + "/{}_{}_{}.pkl".format(net_name, args.dataset, epoch))
def main():
    """Train DeepLab on PASCAL VOC starting from a COCO-pretrained checkpoint.

    Builds train/val loaders, initializes DeepLab from a pretrained file,
    optionally resumes from the best checkpoint, then trains for 20 epochs,
    validating periodically and keeping the best model by the metric that
    ``validate`` returns.

    NOTE(review): Python 2 code (`print current_metric` statement) with
    legacy PyTorch (<0.4) APIs; kept as-is.
    """
    # ---- configs ----
    best_metric = 0
    pretrain_deeplab_path = os.path.join(configs.py_dir, 'model/deeplab_coco.pth')
    # ---- load datasets (both splits resized to 321x321) ----
    train_transform_det = trans.Compose([
        trans.Scale((321, 321)),
    ])
    val_transform_det = trans.Compose([
        trans.Scale((321, 321)),
    ])
    train_data = voc_dates.VOCDataset(configs.train_img_dir, configs.train_label_dir,
                                      configs.train_txt_dir, 'train', transform=True,
                                      transform_med=train_transform_det)
    train_loader = Data.DataLoader(train_data, batch_size=configs.batch_size,
                                   shuffle=True, num_workers=4, pin_memory=True)
    val_data = voc_dates.VOCDataset(configs.val_img_dir, configs.val_label_dir,
                                    configs.val_txt_dir, 'val', transform=True,
                                    transform_med=val_transform_det)
    val_loader = Data.DataLoader(val_data, batch_size=configs.batch_size,
                                 shuffle=False, num_workers=4, pin_memory=True)
    # ---- build models ----
    deeplab = models.deeplab()
    deeplab_pretrain_model = utils.load_deeplab_pretrain_model(pretrain_deeplab_path)
    deeplab.init_parameters(deeplab_pretrain_model)
    deeplab = deeplab.cuda()
    params = list(deeplab.parameters())  # NOTE(review): unused
    # NOTE(review): `resume` is not defined in this function -- presumably a
    # module-level flag; confirm before relying on it.
    if resume:
        checkpoint = torch.load(configs.best_ckpt_dir)
        deeplab.load_state_dict(checkpoint['state_dict'])
        print('resum sucess')
    # ---- optimizer ----
    # Bias parameters get 2x learning rate and no weight decay (common
    # DeepLab training recipe expressed via per-group options).
    optimizer = torch.optim.SGD(
        [
            {'params': get_parameters(deeplab, bias=False)},
            {'params': get_parameters(deeplab, bias=True),
             'lr': configs.learning_rate * 2, 'weight_decay': 0},
        ], lr=configs.learning_rate, momentum=configs.momentum,
        weight_decay=configs.weight_decay)
    # ---- iterate over img/label pairs ----
    for epoch in range(20):
        utils.adjust_learning_rate(configs.learning_rate, optimizer, epoch)
        for batch_idx, batch in enumerate(train_loader):
            img_idx, label_idx, filename, height, width = batch
            img, label = Variable(img_idx.cuda()), Variable(label_idx.cuda())
            prediction, weights = deeplab(img)
            loss = utils.cross_entropy2d(prediction, label, size_average=False)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch_idx) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" % (epoch, batch_idx, loss.data[0]))
            if (batch_idx) % 4000 == 0:
                # Mid-epoch validation every 4000 batches (also fires at batch 0).
                current_metric = validate(deeplab, val_loader, epoch)
                print current_metric
        # End-of-epoch validation drives best-checkpoint selection.
        current_metric = validate(deeplab, val_loader, epoch)
        if current_metric > best_metric:
            torch.save({'state_dict': deeplab.state_dict()},
                       os.path.join(configs.save_ckpt_dir, 'deeplab' + str(epoch) + '.pth'))
            shutil.copy(os.path.join(configs.save_ckpt_dir, 'deeplab' + str(epoch) + '.pth'),
                        os.path.join(configs.save_ckpt_dir, 'model_best.pth'))
            best_metric = current_metric
        # Periodic snapshot regardless of metric.
        if epoch % 5 == 0:
            torch.save({'state_dict': deeplab.state_dict()},
                       os.path.join(configs.save_ckpt_dir, 'deeplab' + str(epoch) + '.pth'))
def validate(self):
    """
    Function to validate a training model on the val split.

    Accumulates per-batch cross-entropy loss over the val loader, saves a
    tiled visualization of up to 8 predictions, appends a CSV log row, and
    writes a checkpoint (copied to model_best when mean IoU improves).

    NOTE(review): legacy PyTorch (<0.4) APIs (Variable, volatile=True,
    loss.data[0]); kept as-is.
    """
    self.model.eval()
    val_loss = 0
    num_vis = 8  # cap on saved visualizations
    visualizations = []
    label_trues, label_preds = [], []
    # Loop to forward pass the data points into the model and measure the performance
    for batch_idx, (data, target) in tqdm.tqdm(enumerate(self.val_loader), total=len(self.val_loader),
                                               desc='Valid iteration=%d' % self.iteration, ncols=80, leave=False):
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables autograd history (pre-0.4 inference mode).
        data, target = Variable(data, volatile=True), Variable(target)
        score = self.model(data)
        loss = cross_entropy2d(score, target, size_average=self.size_average)
        if np.isnan(float(loss.data[0])):
            raise ValueError('loss is nan while validating')
        val_loss += float(loss.data[0]) / len(data)
        imgs = data.data.cpu()
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu()
        # Function to save visualizations of the predicted label map
        for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
            img = self.val_loader.dataset.untransform(img.numpy())
            lt = lt.numpy()
            label_trues.append(lt)
            label_preds.append(lp)
            if len(visualizations) < num_vis:
                viz = fcn.utils.visualize_segmentation(
                    lbl_pred=lp, lbl_true=lt, img=img, n_class=self.n_class)
                visualizations.append(viz)
    # Measuring the performance
    metrics = torchfcn.utils.label_accuracy_score(label_trues, label_preds, self.n_class)
    out = osp.join(self.out, 'visualization_viz')
    if not osp.exists(out):
        os.makedirs(out)
    out_file = osp.join(out, 'iter%012d.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
    val_loss /= len(self.val_loader)
    # Logging
    with open(osp.join(self.out, 'log.csv'), 'a') as f:
        elapsed_time = \
            datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
            self.timestamp_start
        # Five empty columns keep validation rows aligned with train rows.
        log = [self.epoch, self.iteration] + [''] * 5 + \
            [val_loss] + list(metrics) + [elapsed_time]
        log = map(str, log)
        f.write(','.join(log) + '\n')
    # Saving the model checkpoint: always save, copy to model_best on improvement.
    mean_iu = metrics[2]
    is_best = mean_iu > self.best_mean_iu
    if is_best:
        self.best_mean_iu = mean_iu
    torch.save(
        {
            'epoch': self.epoch,
            'iteration': self.iteration,
            'arch': self.model.__class__.__name__,
            'optim_state_dict': self.optim.state_dict(),
            'model_state_dict': self.model.state_dict(),
            'best_mean_iu': self.best_mean_iu,
        }, osp.join(self.out, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                    osp.join(self.out, 'model_best.pth.tar'))
def train_epoch(self):
    """Train the segmentation model for one epoch over the source loader.

    Validates every ``interval_validate`` iterations, logs per-batch metrics
    to log.csv, and stops once ``max_iter`` is reached.

    Fixes vs. original: gradients were zeroed twice per iteration (once is
    enough); the NaN check ran *after* ``optim.step()``, so a NaN loss had
    already corrupted the weights before the error was raised; removed the
    unused local ``batch_size``.
    """
    self.model.train()
    # Loop for training the model
    for batch_idx, datas in tqdm.tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                                      desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
        self.iteration = batch_idx + self.epoch * len(self.train_loader)
        # Periodic validation (skipping iteration 0).
        if self.iteration % self.interval_validate == 0 and self.iteration > 0:
            self.validate()
            self.model.train()
        # Obtaining data in the right format
        data_source, labels_source = datas
        if self.cuda:
            data_source, labels_source = data_source.cuda(), labels_source.cuda()
        data_source, labels_source = Variable(data_source), Variable(labels_source)
        # Forward pass and segmentation loss, normalized by batch size.
        source_pred = self.model(data_source)
        loss_seg = cross_entropy2d(source_pred, labels_source,
                                   size_average=self.size_average)
        loss_seg /= len(data_source)
        # Fail fast BEFORE the optimizer step so a NaN loss cannot be
        # applied to the weights.
        if np.isnan(float(loss_seg.data[0])):
            raise ValueError('loss is nan while training')
        # Backward pass (single zero_grad per iteration).
        self.optim.zero_grad()
        loss_seg.backward()
        self.optim.step()
        # Computing and logging performance metrics
        metrics = []
        lbl_pred = source_pred.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = labels_source.data.cpu().numpy()
        for lt, lp in zip(lbl_true, lbl_pred):
            acc, acc_cls, mean_iu, fwavacc = \
                torchfcn.utils.label_accuracy_score(
                    [lt], [lp], n_class=self.n_class)
            metrics.append((acc, acc_cls, mean_iu, fwavacc))
        metrics = np.mean(metrics, axis=0)
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration] + [loss_seg.data[0]] + \
                metrics.tolist() + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')
        if self.iteration >= self.max_iter:
            break
def train_epoch(self):
    """
    Function to train the model for one epoch.

    Joint ACGAN-style update of F (self.model, segmenter), G (self.netG,
    generator) and D (self.netD, discriminator with domain branch *_s and
    classifier branch *_c), driven by lockstepped source/target batches.

    NOTE(review): Python 2 (`itertools.izip`) and legacy PyTorch (<0.4)
    APIs; kept as-is.
    """
    self.model.train()
    self.netG.train()
    self.netD.train()
    for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)),
            total=min(len(self.target_loader), len(self.train_loader)),
            desc='Train epoch = %d' % self.epoch, ncols=80, leave=False):
        data_source, labels_source = datas
        data_target, __ = datat  # target labels unused
        data_source_forD = torch.zeros((data_source.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
        data_target_forD = torch.zeros((data_target.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
        # We pass the unnormalized data to the discriminator. So, the GANs
        # produce images without data normalization
        for i in range(data_source.size()[0]):
            data_source_forD[i] = self.train_loader.dataset.transform_forD(data_source[i], self.image_size_forD, resize=False, mean_add=True)
            data_target_forD[i] = self.train_loader.dataset.transform_forD(data_target[i], self.image_size_forD, resize=False, mean_add=True)
        iteration = batch_idx + self.epoch * min(len(self.train_loader), len(self.target_loader))
        self.iteration = iteration
        if self.cuda:
            data_source, labels_source = data_source.cuda(), labels_source.cuda()
            data_target = data_target.cuda()
            data_source_forD = data_source_forD.cuda()
            data_target_forD = data_target_forD.cuda()
        data_source, labels_source = Variable(data_source), Variable(labels_source)
        data_target = Variable(data_target)
        data_source_forD = Variable(data_source_forD)
        data_target_forD = Variable(data_target_forD)
        # Source domain
        score, fc7, pool4, pool3 = self.model(data_source)
        outG_src = self.netG(fc7, pool4, pool3)
        outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)
        outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
        # target domain
        tscore, tfc7, tpool4, tpool3 = self.model(data_target)
        outG_tgt = self.netG(tfc7, tpool4, tpool3)
        outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
        outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)
        # Creating labels for D. We need two sets of labels since our model
        # is a ACGAN style framework.
        # (1) Labels for the classifier branch. This will be a downsampled
        #     version of original segmentation labels
        # (2) Domain labels for classifying source real, source fake,
        #     target real and target fake
        # Labels for classifier branch
        Dout_sz = outD_src_real_s.size()
        label_forD = torch.zeros((outD_tgt_fake_c.size()[0], outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
        for i in range(label_forD.size()[0]):
            label_forD[i] = self.train_loader.dataset.transform_label_forD(labels_source[i], (outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
        if self.cuda:
            label_forD = label_forD.cuda()
        label_forD = Variable(label_forD.long())
        # Domain labels: 0 = src real, 1 = src fake, 2 = tgt real, 3 = tgt fake
        domain_labels_src_real = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_()
        domain_labels_src_fake = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 1
        domain_labels_tgt_real = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 2
        domain_labels_tgt_fake = torch.LongTensor(Dout_sz[0], Dout_sz[2], Dout_sz[3]).zero_() + 3
        domain_labels_src_real = Variable(domain_labels_src_real.cuda())
        domain_labels_src_fake = Variable(domain_labels_src_fake.cuda())
        domain_labels_tgt_real = Variable(domain_labels_tgt_real.cuda())
        domain_labels_tgt_fake = Variable(domain_labels_tgt_fake.cuda())
        # Updates.
        # There are three sets of updates - (1) Discriminator, (2) Generator
        # and (3) F network
        # (1) Discriminator updates
        lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)
        lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)
        lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)
        lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)
        lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)
        self.optimD.zero_grad()
        lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
        lossD /= len(data_source)
        lossD.backward(retain_graph=True)  # graph reused by the G and F updates
        self.optimD.step()
        # (2) Generator updates: fool D's domain branch (fakes scored with
        # "real" labels) + classifier branch + L1 reconstruction.
        self.optimG.zero_grad()
        lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real, size_average=self.size_average)
        lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD, size_average=self.size_average)
        lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real, size_average=self.size_average)
        lossG_src_mse = F.l1_loss(outG_src, data_source_forD)  # named *_mse but computed as L1
        lossG_tgt_mse = F.l1_loss(outG_tgt, data_target_forD)
        lossG = lossG_src_adv_c + 0.1 * (lossG_src_adv_s + lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
        lossG /= len(data_source)
        lossG.backward(retain_graph=True)
        self.optimG.step()
        # (3) F network updates: segmentation loss plus adversarial terms
        # with the domain labels swapped across domains.
        self.optim.zero_grad()
        lossC = cross_entropy2d(score, labels_source, size_average=self.size_average)
        lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real, size_average=self.size_average)
        lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real, size_average=self.size_average)
        lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD, size_average=self.size_average)
        lossF = lossC + self.adv_weight * (lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight * lossF_src_adv_c
        lossF /= len(data_source)
        lossF.backward()
        self.optim.step()
        # NaN guards (run after the optimizer steps).
        if np.isnan(float(lossD.data[0])):
            raise ValueError('lossD is nan while training')
        if np.isnan(float(lossG.data[0])):
            raise ValueError('lossG is nan while training')
        if np.isnan(float(lossF.data[0])):
            raise ValueError('lossF is nan while training')
        # Computing metrics for logging
        metrics = []
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = labels_source.data.cpu().numpy()
        for lt, lp in zip(lbl_true, lbl_pred):
            acc, acc_cls, mean_iu, fwavacc = \
                torchfcn.utils.label_accuracy_score(
                    [lt], [lp], n_class=self.n_class)
            metrics.append((acc, acc_cls, mean_iu, fwavacc))
        metrics = np.mean(metrics, axis=0)
        # Logging
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration] + [lossF.data[0]] + \
                metrics.tolist() + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')
        if self.iteration >= self.max_iter:
            break
        # Validating periodically
        if self.iteration % self.interval_validate == 0 and self.iteration > 0:
            out_recon = osp.join(self.out, 'visualization_viz')
            if not osp.exists(out_recon):
                os.makedirs(out_recon)
            generations = []
            # Saving generated source and target images
            source_img = self.val_loader.dataset.untransform(data_source.data.cpu().numpy().squeeze())
            target_img = self.val_loader.dataset.untransform(data_target.data.cpu().numpy().squeeze())
            outG_src_ = (outG_src) * 255.0
            outG_tgt_ = (outG_tgt) * 255.0
            # CHW -> HWC, swap channel order, scale to uint8 for imsave.
            outG_src_ = outG_src_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            outG_tgt_ = outG_tgt_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            generations.append(source_img)
            generations.append(outG_src_)
            generations.append(target_img)
            generations.append(outG_tgt_)
            out_file = osp.join(out_recon, 'iter%012d_src_target_recon.png' % self.iteration)
            scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))
            # Validation
            self.validate()
            self.model.train()
            self.netG.train()
def validate(self):
    """
    Function to validate a training model on the val split.

    Evaluates segmenter + generator over the val loader, saves tiled
    label-map and generation images, appends a CSV log row, and writes a
    checkpoint (copied to model_best when mean IoU improves).

    NOTE(review): legacy PyTorch (<0.4) APIs (Variable, volatile=True,
    loss.data[0]); kept as-is.
    """
    self.model.eval()
    self.netG.eval()
    val_loss = 0
    num_vis = 8  # cap on saved label-map visualizations
    visualizations = []
    generations = []
    label_trues, label_preds = [], []
    # Evaluation
    for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.val_loader), total=len(self.val_loader),
            desc='Validation iteration = %d' % self.iteration, ncols=80, leave=False):
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables autograd history (pre-0.4 inference mode).
        data, target = Variable(data, volatile=True), Variable(target)
        # Score plus intermediate features consumed by the generator.
        score, fc7, pool4, pool3 = self.model(data)
        outG = self.netG(fc7, pool4, pool3)
        loss = cross_entropy2d(score, target, size_average=self.size_average)
        if np.isnan(float(loss.data[0])):
            raise ValueError('loss is nan while validating')
        val_loss += float(loss.data[0]) / len(data)
        imgs = data.data.cpu()
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu()
        # Visualizing predicted labels
        for img, lt, lp, outG_ in zip(imgs, lbl_true, lbl_pred, outG):
            # Scale generator output to uint8; CHW -> HWC with channel swap.
            outG_ = outG_ * 255.0
            outG_ = outG_.data.cpu().numpy().squeeze().transpose((1, 2, 0))[:, :, ::-1].astype(np.uint8)
            img = self.val_loader.dataset.untransform(img.numpy())
            lt = lt.numpy()
            label_trues.append(lt)
            label_preds.append(lp)
            if len(visualizations) < num_vis:
                viz = fcn.utils.visualize_segmentation(
                    lbl_pred=lp, lbl_true=lt, img=img, n_class=self.n_class)
                visualizations.append(viz)
                generations.append(outG_)
    # Computing the metrics
    metrics = torchfcn.utils.label_accuracy_score(
        label_trues, label_preds, self.n_class)
    val_loss /= len(self.val_loader)
    # Saving the label visualizations and generations
    out = osp.join(self.out, 'visualization_viz')
    if not osp.exists(out):
        os.makedirs(out)
    out_file = osp.join(out, 'iter%012d_labelmap.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
    out_file = osp.join(out, 'iter%012d_generations.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))
    # Logging
    with open(osp.join(self.out, 'log.csv'), 'a') as f:
        elapsed_time = \
            datetime.datetime.now(pytz.timezone('Asia/Tokyo')) - \
            self.timestamp_start
        # Five empty columns keep validation rows aligned with train rows.
        log = [self.epoch, self.iteration] + [''] * 5 + \
            [val_loss] + list(metrics) + [elapsed_time]
        log = map(str, log)
        f.write(','.join(log) + '\n')
    # Saving the models: checkpoint always, copy to model_best on improvement.
    mean_iu = metrics[2]
    is_best = mean_iu > self.best_mean_iu
    if is_best:
        self.best_mean_iu = mean_iu
    torch.save({
        'epoch': self.epoch,
        'iteration': self.iteration,
        'arch': self.model.__class__.__name__,
        'optim_state_dict': self.optim.state_dict(),
        'model_state_dict': self.model.state_dict(),
        'best_mean_iu': self.best_mean_iu,
    }, osp.join(self.out, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                    osp.join(self.out, 'model_best.pth.tar'))
def train_epoch(self):
    """Train the model for one epoch over ``self.train_loader``.

    Runs ``self.validate()`` every ``self.interval_validate`` iterations,
    appends per-iteration metrics to ``self.out/log.csv``, and stops early
    once ``self.iteration`` reaches ``self.max_iter``.

    Fixes vs. previous version:
    - NaN check moved BEFORE backward/step (previously the optimizer had
      already applied NaN gradients before the error was raised).
    - Removed duplicate ``self.optim.zero_grad()`` call.
    - Removed unused local ``batch_size``.
    """
    self.model.train()
    # Loop for training the model
    for batch_idx, datas in tqdm.tqdm(
            enumerate(self.train_loader), total=len(self.train_loader),
            desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
        iteration = batch_idx + self.epoch * len(self.train_loader)
        self.iteration = iteration
        # Periodic validation (validate() puts the model into eval mode)
        if self.iteration % self.interval_validate == 0 and self.iteration > 0:
            self.validate()
            self.model.train()
        # Obtaining data in the right format
        data_source, labels_source = datas
        if self.cuda:
            data_source, labels_source = \
                data_source.cuda(), labels_source.cuda()
        data_source, labels_source = \
            Variable(data_source), Variable(labels_source)
        # Forward pass
        self.optim.zero_grad()
        source_pred = self.model(data_source)
        # Computing the segmentation loss (normalized by batch size)
        loss_seg = cross_entropy2d(source_pred, labels_source,
                                   size_average=self.size_average)
        loss_seg /= len(data_source)
        # Abort before touching the weights if the loss diverged
        if np.isnan(float(loss_seg.data[0])):
            raise ValueError('loss is nan while training')
        # Updating the model (backward pass)
        loss_seg.backward()
        self.optim.step()
        # Computing and logging performance metrics
        metrics = []
        lbl_pred = source_pred.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = labels_source.data.cpu().numpy()
        for lt, lp in zip(lbl_true, lbl_pred):
            acc, acc_cls, mean_iu, fwavacc = \
                torchfcn.utils.label_accuracy_score(
                    [lt], [lp], n_class=self.n_class)
            metrics.append((acc, acc_cls, mean_iu, fwavacc))
        metrics = np.mean(metrics, axis=0)
        # CSV row: [epoch, iteration, loss, metrics..., 5 empty val slots,
        # elapsed seconds]
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration] + [loss_seg.data[0]] + \
                metrics.tolist() + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')
        if self.iteration >= self.max_iter:
            break
def test(self):
    """Evaluate the model on the test split and log results.

    Side effects: saves a tiled visualization image under
    ``self.out/visualization_viz``, appends a row to ``self.out/log.csv``,
    and pushes scalar summaries to ``self.ts_logger`` (tensorboard).
    Restores the model's previous train/eval mode before returning.

    Fixes vs. previous version: removed a stray no-op expression statement
    (``len(self.train_loader)``) and the unused local ``msg``.
    """
    training = self.model.training  # remember mode so we can restore it
    self.model.eval()
    n_class = len(self.test_loader.dataset.class_names)
    test_loss = 0
    visualizations = []
    label_trues, label_preds = [], []
    for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.test_loader), total=len(self.test_loader),
            desc='Test iteration=%d' % self.iteration,
            ncols=80, leave=False):
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True: legacy (pre-0.4) no-grad inference
        data, target = Variable(data, volatile=True), Variable(target)
        score = self.model(data)
        loss = utils.cross_entropy2d(score, target,
                                     size_average=self.size_average)
        loss_data = float(loss.data[0])
        if np.isnan(loss_data):
            raise ValueError('loss is nan while testing')
        test_loss += loss_data / len(data)
        imgs = data.data.cpu()
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :].astype(
            np.uint8)
        lbl_true = target.data.cpu().numpy()
        for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
            img, lt = self.test_loader.dataset.untransform(img, lt)
            label_trues.append(lt)
            label_preds.append(lp)
            if len(visualizations) < 9:  # at most 9 tiles in the mosaic
                viz = fcn.utils.visualize_segmentation(
                    lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class)
                visualizations.append(viz)
    metrics = utils.label_accuracy_score(label_trues, label_preds, n_class)
    # Save the tiled qualitative results
    out = osp.join(self.out, 'visualization_viz')
    if not osp.exists(out):
        os.makedirs(out)
    out_file = osp.join(out, 'iter_test_%012d.jpg' % self.iteration)
    scipy.misc.imsave(out_file, fcn.utils.get_tile_image(visualizations))
    test_loss /= len(self.test_loader)
    # CSV row: [epoch, iteration, 5 empty train slots, test_loss,
    # metrics..., elapsed seconds]
    with open(osp.join(self.out, 'log.csv'), 'a') as f:
        elapsed_time = (
            datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
            self.timestamp_start).total_seconds()
        log = [self.epoch, self.iteration] + [''] * 5 + \
            [test_loss] + list(metrics) + [elapsed_time]
        log = map(str, log)
        f.write(','.join(log) + '\n')
    # logging information for tensorboard
    info = OrderedDict({
        "loss": test_loss,
        "acc": metrics[0],
        "acc_cls": metrics[1],
        "meanIoU": metrics[2],
        "fwavacc": metrics[3],
        "bestIoU": self.best_mean_iu,
    })
    # fractional epoch position used as the x-axis of the scalar plots
    partial_epoch = self.iteration / len(self.train_loader)
    for tag, value in info.items():
        self.ts_logger.scalar_summary(tag, value, partial_epoch)
    if training:
        self.model.train()
def train_epoch(self):
    """Train the model for one epoch with class-weighted cross entropy.

    Runs ``self.validate()`` and ``self.test()`` every
    ``self.interval_validate`` iterations, logs per-iteration metrics to
    ``self.out/log.csv`` and tensorboard, and stops when ``self.iteration``
    reaches ``self.max_iter``.

    Fixes vs. previous version: the class-weight tensor was rebuilt from
    numpy on EVERY iteration and moved to GPU unconditionally (crashing
    CPU-only runs); it is now built once per epoch and only moved to GPU
    when ``self.cuda`` is set, matching the data/target transfers.
    """
    self.model.train()
    n_class = len(self.train_loader.dataset.class_names)
    # Loop-invariant loss configuration, hoisted out of the batch loop
    weights = torch.from_numpy(
        self.train_loader.dataset.class_weights).float()
    if self.cuda:
        weights = weights.cuda()
    ignore = self.train_loader.dataset.class_ignore
    for batch_idx, (data, target) in tqdm.tqdm(
            enumerate(self.train_loader), total=len(self.train_loader),
            desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
        iteration = batch_idx + self.epoch * len(self.train_loader)
        if self.iteration != 0 and (iteration - 1) != self.iteration:
            continue  # for resuming
        self.iteration = iteration
        if self.iteration % self.interval_validate == 0:
            self.validate()
            self.test()
        assert self.model.training
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        self.optim.zero_grad()
        score = self.model(data)
        loss = utils.cross_entropy2d(score, target, weight=weights,
                                     size_average=self.size_average,
                                     ignore=ignore)
        loss /= len(data)
        loss_data = float(loss.data[0])
        if np.isnan(loss_data):
            raise ValueError('loss is nan while training')
        loss.backward()
        self.optim.step()
        # Per-batch segmentation metrics
        metrics = []
        lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
        lbl_true = target.data.cpu().numpy()
        acc, acc_cls, mean_iu, fwavacc = utils.label_accuracy_score(
            lbl_true, lbl_pred, n_class=n_class)
        metrics.append((acc, acc_cls, mean_iu, fwavacc))
        metrics = np.mean(metrics, axis=0)
        # CSV row: [epoch, iteration, loss, metrics..., 5 empty val slots,
        # elapsed seconds]
        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch, self.iteration] + [loss_data] + \
                metrics.tolist() + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')
        # logging to tensorboard
        self.best_train_meanIoU = max(self.best_train_meanIoU, metrics[2])
        info = OrderedDict({
            "loss": loss.data[0],
            "acc": metrics[0],
            "acc_cls": metrics[1],
            "meanIoU": metrics[2],
            "fwavacc": metrics[3],
            "bestIoU": self.best_train_meanIoU,
        })
        partialEpoch = self.epoch + float(batch_idx) / len(self.train_loader)
        for tag, value in info.items():
            self.t_logger.scalar_summary(tag, value, partialEpoch)
        if self.iteration >= self.max_iter:
            break
def main():
    """Train an FCN32s on VOC data and checkpoint the best model by val metric."""
    ######### configs ###########
    best_metric = 0
    pretrain_vgg16_path = os.path.join(configs.py_dir,
                                       'model/vgg16_from_caffe.pth')
    ###### load datasets ########
    train_data = voc_dates.VOCDataset(configs.train_img_dir,
                                      configs.train_label_dir,
                                      configs.train_txt_dir,
                                      'train', transform=True)
    train_loader = Data.DataLoader(train_data,
                                   batch_size=configs.batch_size,
                                   shuffle=True, num_workers=4,
                                   pin_memory=True)
    val_data = voc_dates.VOCDataset(configs.val_img_dir,
                                    configs.val_label_dir,
                                    configs.val_txt_dir,
                                    'val', transform=True)
    val_loader = Data.DataLoader(val_data,
                                 batch_size=configs.batch_size,
                                 shuffle=False, num_workers=4,
                                 pin_memory=True)
    ###### build models ########
    fcn32s = models.fcn32s()
    # initialize the FCN backbone from Caffe-converted VGG16 weights
    vgg_pretrain_model = utils.load_pretrain_model(pretrain_vgg16_path)
    fcn32s.init_parameters(vgg_pretrain_model)
    fcn32s = fcn32s.cuda()
    #########
    # NOTE(review): `resume` is not defined in this function — presumably a
    # module-level flag; verify it exists, otherwise this raises NameError.
    if resume:
        checkpoint = torch.load(configs.best_ckpt_dir)
        fcn32s.load_state_dict(checkpoint['state_dict'])
        print('resum sucess')
    ######### optimizer ##########
    ######## how to set different learning rate for differern layer #########
    # Bias parameters get 2x learning rate and no weight decay (common FCN
    # training recipe); non-bias parameters use the base settings.
    optimizer = torch.optim.SGD([
        {
            'params': get_parameters(fcn32s, bias=False)
        },
        {
            'params': get_parameters(fcn32s, bias=True),
            'lr': configs.learning_rate * 2,
            'weight_decay': 0
        },
    ], lr=configs.learning_rate, momentum=configs.momentum,
        weight_decay=configs.weight_decay)
    ######## iter img_label pairs ###########
    for epoch in range(20):
        utils.adjust_learning_rate(configs.learning_rate, optimizer, epoch)
        for batch_idx, (img_idx, label_idx) in enumerate(train_loader):
            img, label = Variable(img_idx.cuda()), Variable(label_idx.cuda())
            prediction = fcn32s(img)
            # size_average=False: loss is summed over pixels, not averaged
            loss = utils.cross_entropy2d(prediction, label,
                                         size_average=False)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch_idx) % 20 == 0:
                # NOTE(review): second %d is the batch index, not a total —
                # the "[%d/%d]" label is misleading; confirm intent.
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch, batch_idx, loss.data[0]))
        # Checkpoint whenever the validation metric improves; the per-epoch
        # file is also copied to model_best.pth.
        current_metric = validate(fcn32s, val_loader, epoch)
        if current_metric > best_metric:
            torch.save({'state_dict': fcn32s.state_dict()},
                       os.path.join(configs.save_ckpt_dir,
                                    'fcn32s' + str(epoch) + '.pth'))
            shutil.copy(
                os.path.join(configs.save_ckpt_dir,
                             'fcn32s' + str(epoch) + '.pth'),
                os.path.join(configs.save_ckpt_dir, 'model_best.pth'))
            best_metric = current_metric
        # Periodic snapshot every 5 epochs regardless of metric
        if epoch % 5 == 0:
            torch.save({'state_dict': fcn32s.state_dict()},
                       os.path.join(configs.save_ckpt_dir,
                                    'fcn32s' + str(epoch) + '.pth'))
# ], lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Plain SGD over all model parameters.
# Fix: `lr=alearning_rate` was a typo for `learning_rate` (see the
# commented-out original line above, which uses `learning_rate`).
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            momentum=momentum, weight_decay=weight_decay)
# Step the LR down by 10x at 40%, 70%, 80% and 90% of training
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[
    int(0.4 * end_epoch),
    int(0.7 * end_epoch),
    int(0.8 * end_epoch),
    int(0.9 * end_epoch)
], gamma=0.1)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',patience=10, verbose=True)
criterion = cross_entropy2d()
# resume: restore model/optimizer state (scheduler and start epoch are
# deliberately left commented out)
if (os.path.isfile(resume_path) and resume_flag):
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_iou = checkpoint['best_iou']
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))