def train_epoch(self):
    """Train the model for one epoch."""

    def set_requires_grad(seg, dis):
        # Toggle which sub-network receives gradients.
        for param in self.model.parameters():
            param.requires_grad = seg
        for param in self.netD.parameters():
            param.requires_grad = dis

    # itertools.izip is Python 2; on Python 3 use the built-in zip instead.
    for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)),
            total=self.iters_per_epoch,
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
        self.iteration = batch_idx + self.epoch * self.iters_per_epoch

        source_data, source_labels = datas
        target_data, _ = datat

        self.optim.zero_grad()
        self.optimD.zero_grad()

        src_dis_label = 1
        target_dis_label = 0

        if self.cuda:
            source_data, source_labels = source_data.cuda(), source_labels.cuda()
            target_data = target_data.cuda()
        source_data, source_labels = Variable(source_data), Variable(source_labels)
        target_data = Variable(target_data)

        # ############ train G, item 1: segmentation + distillation ############
        # set_requires_grad(seg=True, dis=False)

        # Source domain: supervised segmentation loss.
        score = self.model(source_data)
        l_seg = CrossEntropyLoss2d_Seg(score, source_labels,
                                       class_num=class_num,
                                       size_average=self.size_average)

        # Target domain: distill towards the frozen source model.
        seg_target_score = self.model(target_data)
        modelfix_target_score = self.model_fix(target_data)
        diff2d = Diff2d()
        distill_loss = diff2d(seg_target_score, modelfix_target_score)

        seg_loss = l_seg + 10 * distill_loss
        # seg_loss.backward(retain_graph=True)

        # ####### train G, item 2 (disabled: adversarial loss on G) #######
        """
        bce_loss = torch.nn.BCEWithLogitsLoss()
        src_discriminate_result = self.netD(score)
        target_discriminate_result = self.netD(seg_target_score)
        src_dis_loss = bce_loss(
            src_discriminate_result,
            Variable(torch.FloatTensor(
                src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
        target_dis_loss = bce_loss(
            target_discriminate_result,
            Variable(torch.FloatTensor(
                target_discriminate_result.data.size()).fill_(target_dis_label)).cuda())
        dis_loss = src_dis_loss + target_dis_loss
        dis_loss.backward(retain_graph=True)
        """

        # ####################### train D #######################
        # set_requires_grad(seg=False, dis=True)
        # The scores are detached so the discriminator loss does not
        # backpropagate into the segmentation network.
        bce_loss = torch.nn.BCEWithLogitsLoss()
        src_discriminate_result = self.netD(score.detach())
        target_discriminate_result = self.netD(seg_target_score.detach())
        src_dis_loss = bce_loss(
            src_discriminate_result,
            Variable(torch.FloatTensor(
                src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
        target_dis_loss = bce_loss(
            target_discriminate_result,
            Variable(torch.FloatTensor(
                target_discriminate_result.data.size()).fill_(target_dis_label)).cuda())
        dis_loss = src_dis_loss + target_dis_loss  # this loss has been inversed!!

        total_loss = dis_loss + seg_loss
        total_loss.backward()
        self.optim.step()
        self.optimD.step()

        if np.isnan(float(dis_loss.data[0])):
            raise ValueError('dis_loss is nan while training')
        if np.isnan(float(seg_loss.data[0])):
            raise ValueError('seg_loss is nan while training')

        if self.iteration % self.loss_print_interval == 0:
            logger.info(
                "L_SEG={}, Distill_LOSS={}, Discriminator loss={}".format(
                    l_seg.data[0], distill_loss.data[0], dis_loss.data[0]))
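# ---------------------------------------------------------------------------
# The methods in this file rely on module-level imports and helpers defined
# elsewhere in the repository (tqdm, itertools, numpy as np, torch,
# torch.nn.functional as F, torch.autograd.Variable, logger, class_num,
# CrossEntropyLoss2d_Seg, Diff2d). As a point of reference, here is a minimal
# sketch of what Diff2d is assumed to compute: the mean absolute difference
# between the softmax maps of two score tensors, a common distillation
# discrepancy. This is a plausible reconstruction, not the repository's
# actual definition.
import torch
import torch.nn.functional as F


class Diff2d(torch.nn.Module):
    """Assumed distillation loss: mean |softmax(a) - softmax(b)|."""

    def forward(self, inputs1, inputs2):
        # Both inputs are (N, C, H, W) logits; softmax over the channel dim
        # (the pre-0.4 default for 4D inputs).
        return torch.mean(torch.abs(F.softmax(inputs1) - F.softmax(inputs2)))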
def train_epoch(self):
    """Train the model for one epoch."""

    def set_requires_grad(seg, dis):
        # Toggle which sub-network receives gradients.
        for param in self.model.parameters():
            param.requires_grad = seg
        for param in self.netD.parameters():
            param.requires_grad = dis

    # Independent loader iterators for the G and D update steps.
    self.G_source_loader_iter = [
        enumerate(self.train_loader) for _ in range(G_STEP)
    ]
    self.G_target_loader_iter = [
        enumerate(self.target_loader) for _ in range(G_STEP)
    ]
    self.D_source_loader_iter = [
        enumerate(self.train_loader) for _ in range(D_STEP)
    ]
    self.D_target_loader_iter = [
        enumerate(self.target_loader) for _ in range(D_STEP)
    ]

    for batch_idx in tqdm.tqdm(range(self.iters_per_epoch),
                               total=self.iters_per_epoch,
                               desc='Train epoch = {}/{}'.format(
                                   self.epoch, self.max_epoch)):
        self.iteration = batch_idx + self.epoch * self.iters_per_epoch

        src_dis_label = 1
        target_dis_label = 0
        mse_loss = torch.nn.MSELoss()  # least-squares GAN objective

        def get_data(source_iter, target_iter):
            _, source_batch = next(source_iter)
            source_data, source_labels = source_batch
            _, target_batch = next(target_iter)
            target_data, _ = target_batch
            if self.cuda:
                source_data, source_labels = source_data.cuda(), source_labels.cuda()
                target_data = target_data.cuda()
            source_data, source_labels = Variable(source_data), Variable(source_labels)
            target_data = Variable(target_data)
            return source_data, source_labels, target_data

        # ################################## train D ##################################
        for step in range(D_STEP):
            source_data, source_labels, target_data = get_data(
                self.D_source_loader_iter[step],
                self.D_target_loader_iter[step])
            self.optimD.zero_grad()
            set_requires_grad(seg=False, dis=True)

            score = self.model(source_data)
            seg_target_score = self.model(target_data)

            src_discriminate_result = self.netD(F.softmax(score))
            target_discriminate_result = self.netD(F.softmax(seg_target_score))
            src_dis_loss = mse_loss(
                src_discriminate_result,
                Variable(torch.FloatTensor(
                    src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
            target_dis_loss = mse_loss(
                target_discriminate_result,
                Variable(torch.FloatTensor(
                    target_discriminate_result.data.size()).fill_(target_dis_label)).cuda())
            src_dis_loss = src_dis_loss * DIS_WEIGHT
            target_dis_loss = target_dis_loss * DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss
            dis_loss.backward()
            self.optimD.step()

            # Weight clipping as in WGAN, see
            # https://ewanlee.github.io/2017/04/29/WGAN-implemented-by-PyTorch/
            for p in self.netD.parameters():
                p.data.clamp_(-0.01, 0.01)

        # ##################### train G, item 1: segmentation + distillation
        for step in range(G_STEP):
            source_data, source_labels, target_data = get_data(
                self.G_source_loader_iter[step],
                self.G_target_loader_iter[step])
            self.optim.zero_grad()
            set_requires_grad(seg=True, dis=False)

            # Source domain: supervised segmentation loss.
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score, source_labels,
                                           class_num=class_num,
                                           size_average=self.size_average)

            # Target domain: distill towards the frozen source model.
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)
            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            l_seg = l_seg * L_LOSS_WEIGHT
            distill_loss = distill_loss * DISTILL_WEIGHT
            seg_loss = l_seg + distill_loss

            # ####### train G, item 2: adversarial loss.
            # Both branches are filled with the source label so that the
            # generator is pushed to make target predictions look like source.
            src_discriminate_result = self.netD(F.softmax(score))
            target_discriminate_result = self.netD(F.softmax(seg_target_score))
            src_dis_loss = mse_loss(
                src_discriminate_result,
                Variable(torch.FloatTensor(
                    src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
            target_dis_loss = mse_loss(
                target_discriminate_result,
                Variable(torch.FloatTensor(
                    target_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
            src_dis_loss = src_dis_loss * DIS_WEIGHT
            target_dis_loss = target_dis_loss * DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss

            total_loss = seg_loss + dis_loss
            total_loss.backward()
            self.optim.step()

        if np.isnan(float(dis_loss.data[0])):
            raise ValueError('dis_loss is nan while training')
        if np.isnan(float(seg_loss.data[0])):
            raise ValueError('seg_loss is nan while training')

        if self.iteration % self.loss_print_interval == 0:
            logger.info(
                "After-weighting losses: seg_loss={}, distill_loss={}, "
                "src_dis_loss={}, target_dis_loss={}".format(
                    l_seg.data[0], distill_loss.data[0],
                    src_dis_loss.data[0], target_dis_loss.data[0]))
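# The hyperparameters used above (G_STEP, D_STEP, DIS_TIMES, DIS_WEIGHT,
# L_LOSS_WEIGHT, DISTILL_WEIGHT, class_num) are referenced but not defined in
# this file. A minimal sketch of the assumed module-level configuration; the
# concrete values here are placeholders, not the ones used in the experiments.
G_STEP = 1             # generator updates per iteration
D_STEP = 1             # discriminator updates per iteration
DIS_TIMES = 1          # discriminator repeats in the later variant below
DIS_WEIGHT = 1.0       # weight on the adversarial terms
L_LOSS_WEIGHT = 1.0    # weight on the supervised segmentation loss
DISTILL_WEIGHT = 10.0  # weight on the distillation term (10 in earlier versions)
class_num = 19         # e.g. the 19 Cityscapes evaluation classes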
def train_epoch(self):
    """Train the model for one epoch."""

    def set_requires_grad(seg, dis):
        # Toggle which sub-network receives gradients.
        for param in self.model.parameters():
            param.requires_grad = seg
        for param in self.netD.parameters():
            param.requires_grad = dis

    for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)),
            total=self.iters_per_epoch,
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
        self.iteration = batch_idx + self.epoch * self.iters_per_epoch

        source_data, source_labels = datas
        target_data, _ = datat

        src_dis_label = 1
        target_dis_label = 0
        bce_loss = torch.nn.BCEWithLogitsLoss()

        if self.cuda:
            source_data, source_labels = source_data.cuda(), source_labels.cuda()
            target_data = target_data.cuda()
        source_data, source_labels = Variable(source_data), Variable(source_labels)
        target_data = Variable(target_data)

        # ##################### train G, item 1: segmentation + distillation
        self.optim.zero_grad()
        set_requires_grad(seg=True, dis=False)

        # Source domain: supervised segmentation loss.
        score = self.model(source_data)
        l_seg = CrossEntropyLoss2d_Seg(score, source_labels,
                                       class_num=class_num,
                                       size_average=self.size_average)

        # Target domain: distill towards the frozen source model.
        seg_target_score = self.model(target_data)
        modelfix_target_score = self.model_fix(target_data)
        diff2d = Diff2d()
        distill_loss = diff2d(seg_target_score, modelfix_target_score)

        l_seg = l_seg * L_LOSS_WEIGHT
        distill_loss = distill_loss * DISTILL_WEIGHT
        seg_loss = l_seg + distill_loss

        # ####### train G, item 2: adversarial loss.
        # Both branches are filled with the source label so that the
        # generator is pushed to make target predictions look like source.
        src_discriminate_result = self.netD(F.softmax(score))
        target_discriminate_result = self.netD(F.softmax(seg_target_score))
        src_dis_loss = bce_loss(
            src_discriminate_result,
            Variable(torch.FloatTensor(
                src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
        target_dis_loss = bce_loss(
            target_discriminate_result,
            Variable(torch.FloatTensor(
                target_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
        src_dis_loss = src_dis_loss * DIS_WEIGHT
        target_dis_loss = target_dis_loss * DIS_WEIGHT
        dis_loss = src_dis_loss + target_dis_loss

        total_loss = seg_loss + dis_loss
        total_loss.backward()
        self.optim.step()

        # ################################## train D ##################################
        for _ in range(DIS_TIMES):
            self.optimD.zero_grad()
            set_requires_grad(seg=False, dis=True)
            # With the segmentation weights frozen, these fresh forward
            # passes carry no gradient back into the generator.
            score = self.model(source_data)
            seg_target_score = self.model(target_data)
            src_discriminate_result = self.netD(F.softmax(score))
            target_discriminate_result = self.netD(F.softmax(seg_target_score))
            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(torch.FloatTensor(
                    src_discriminate_result.data.size()).fill_(src_dis_label)).cuda())
            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(torch.FloatTensor(
                    target_discriminate_result.data.size()).fill_(target_dis_label)).cuda())
            src_dis_loss = src_dis_loss * DIS_WEIGHT
            target_dis_loss = target_dis_loss * DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss
            dis_loss.backward()
            self.optimD.step()

        if np.isnan(float(dis_loss.data[0])):
            raise ValueError('dis_loss is nan while training')
        if np.isnan(float(seg_loss.data[0])):
            raise ValueError('seg_loss is nan while training')

        if self.iteration % self.loss_print_interval == 0 or (
                self.epoch == 0 and self.iteration < self.loss_print_interval):
            logger.info(
                "After-weighting losses: seg_loss={}, distill_loss={}, "
                "src_dis_loss={}, target_dis_loss={}".format(
                    l_seg.data[0], distill_loss.data[0],
                    src_dis_loss.data[0], target_dis_loss.data[0]))
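# CrossEntropyLoss2d_Seg is used throughout but not defined in this file.
# A minimal sketch of the assumed 2D segmentation cross-entropy in the same
# pre-0.4 PyTorch style, assuming a void label of 255 as in Cityscapes-style
# setups; hypothetical, for reference only.
def CrossEntropyLoss2d_Seg(score, target, class_num=None, size_average=True):
    # score: (N, C, H, W) logits; target: (N, H, W) int64 class indices.
    # class_num is accepted only for signature compatibility with the callers.
    n, c, h, w = score.size()
    log_p = F.log_softmax(score)  # softmax over the channel dimension
    # Flatten to (N*H*W, C) for nll_loss.
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)
    return F.nll_loss(log_p, target, ignore_index=255,
                      size_average=size_average)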
def train_epoch(self):
    """Train the model for one epoch."""
    self.model.train()
    self.netD.train()

    for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(itertools.izip(self.train_loader, self.target_loader)),
            total=self.iters_per_epoch,
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch),
            leave=False):
        source_data, source_labels = datas
        target_data, _ = datat
        self.iteration = batch_idx + self.epoch * self.iters_per_epoch

        if self.cuda:
            source_data, source_labels = source_data.cuda(), source_labels.cuda()
            target_data = target_data.cuda()
        source_data, source_labels = Variable(source_data), Variable(source_labels)
        target_data = Variable(target_data)

        # TODO: split to 3x3

        # Source domain: supervised segmentation loss.
        score = self.model(source_data)
        l_seg = CrossEntropyLoss2d_Seg(score, source_labels,
                                       size_average=self.size_average)
        src_discriminate_result = self.netD(score)

        # Target domain: distill towards the frozen source model.
        seg_target_score = self.model(target_data)
        modelfix_target_score = self.model_fix(target_data)
        target_discriminate_result = self.netD(seg_target_score)
        diff2d = Diff2d()
        distill_loss = diff2d(seg_target_score, modelfix_target_score)

        bce_loss = torch.nn.BCEWithLogitsLoss()
        src_dis_loss = bce_loss(
            src_discriminate_result,
            Variable(torch.FloatTensor(
                src_discriminate_result.data.size()).fill_(1)).cuda())
        target_dis_loss = bce_loss(
            target_discriminate_result,
            Variable(torch.FloatTensor(
                target_discriminate_result.data.size()).fill_(0)).cuda())
        dis_loss = src_dis_loss + target_dis_loss  # this loss has been inversed!!

        total_loss = l_seg + 10 * distill_loss + dis_loss

        self.optim.zero_grad()
        self.optimD.zero_grad()
        total_loss.backward()
        self.optim.step()
        self.optimD.step()

        if np.isnan(float(dis_loss.data[0])):
            raise ValueError('dis_loss is nan while training')
        if np.isnan(float(total_loss.data[0])):
            raise ValueError('total_loss is nan while training')

        if self.iteration % self.loss_print_interval == 0:
            logger.info(
                "L_SEG={}, Distill_LOSS={}, Discriminator loss={}, TOTAL_LOSS={}"
                .format(l_seg.data[0], distill_loss.data[0],
                        dis_loss.data[0], total_loss.data[0]))

        # TODO: spatial loss

        if self.iteration >= self.max_iter:
            break

        # Validate periodically.
        if self.iteration % self.interval_validate == 0 and self.iteration > 0:
            self.model.eval()
            self.validate()
            self.model.train()  # return to training mode
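# netD is constructed elsewhere in the repository. All variants feed it an
# (N, class_num, H, W) score map and treat its output as per-pixel real/fake
# logits, which suggests a fully convolutional discriminator in the style of
# AdaptSegNet. The sketch below is an assumption about its shape, not the
# repository's actual definition.
import torch.nn as nn


class FCDiscriminator(nn.Module):
    def __init__(self, num_classes, ndf=64):
        super(FCDiscriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_classes, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, stride=2, padding=1),  # raw logits
        )

    def forward(self, x):
        # Returns a downsampled map of real/fake logits, matching the
        # BCEWithLogitsLoss / MSELoss targets built with fill_() above.
        return self.net(x)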
def train_epoch(self):
    """Train the model for one epoch."""

    def set_requires_grad(seg, dis):
        # Toggle which sub-network receives gradients.
        for param in self.model.parameters():
            param.requires_grad = seg
        for param in self.netD.parameters():
            param.requires_grad = dis

    self.train_loader_iter = enumerate(self.train_loader)
    self.target_loader_iter = enumerate(self.target_loader)

    for batch_idx in tqdm.tqdm(
            range(self.iters_per_epoch),
            total=self.iters_per_epoch,
            desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
        self.iteration = batch_idx + self.epoch * self.iters_per_epoch

        self.optim.zero_grad()
        self.optimD.zero_grad()

        sum_l_seg = 0
        sum_distill_loss = 0
        sum_src_dis_loss = 0
        sum_target_dis_loss = 0

        # for sub_iter in range(self.iter_size):
        _, source_batch = next(self.train_loader_iter)
        source_data, source_labels = source_batch
        _, target_batch = next(self.target_loader_iter)
        target_data, _ = target_batch

        if self.cuda:
            source_data, source_labels = source_data.cuda(), source_labels.cuda()
            target_data = target_data.cuda()
        source_data, source_labels = Variable(source_data), Variable(source_labels)
        target_data = Variable(target_data)

        # ####################### train G #######################
        set_requires_grad(seg=True, dis=False)

        # Source domain: supervised segmentation loss.
        score = self.model(source_data)
        l_seg = CrossEntropyLoss2d_Seg(score, source_labels,
                                       class_num=class_num,
                                       size_average=self.size_average)
        sum_l_seg += l_seg.data[0]

        # Target domain: distill towards the frozen source model.
        seg_target_score = self.model(target_data)
        modelfix_target_score = self.model_fix(target_data)
        diff2d = Diff2d()
        distill_loss = diff2d(seg_target_score, modelfix_target_score)
        sum_distill_loss += distill_loss.data[0]

        seg_loss = l_seg + 10 * distill_loss
        seg_loss.backward()

        # ######################### train D #########################
        set_requires_grad(seg=False, dis=True)
        # Detach so D's loss does not flow back into the segmentation net.
        src_discriminate_result = self.netD(score.detach())
        target_discriminate_result = self.netD(seg_target_score.detach())
        bce_loss = torch.nn.BCEWithLogitsLoss()
        src_dis_loss = bce_loss(
            src_discriminate_result,
            Variable(torch.FloatTensor(
                src_discriminate_result.data.size()).fill_(1)).cuda())
        target_dis_loss = bce_loss(
            target_discriminate_result,
            Variable(torch.FloatTensor(
                target_discriminate_result.data.size()).fill_(0)).cuda())
        sum_src_dis_loss += src_dis_loss.data[0]
        sum_target_dis_loss += target_dis_loss.data[0]
        dis_loss = src_dis_loss + target_dis_loss  # this loss has been inversed!!
        dis_loss.backward()

        if np.isnan(float(dis_loss.data[0])):
            raise ValueError('dis_loss is nan while training')
        if np.isnan(float(seg_loss.data[0])):
            raise ValueError('seg_loss is nan while training')

        self.optim.step()
        self.optimD.step()

        if self.iteration % self.loss_print_interval == 0:
            logger.info(
                "L_SEG={}, Distill_LOSS={}, src dis loss={}, target dis loss={}"
                .format(sum_l_seg, sum_distill_loss,
                        sum_src_dis_loss, sum_target_dis_loss))
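# The iterator-based variants above assume iters_per_epoch never exceeds the
# loader length; otherwise next() raises StopIteration mid-epoch. A small
# hypothetical wrapper that restarts an exhausted loader, sketched here as a
# possible hardening (not part of the original code):
class InfiniteLoaderIter(object):
    def __init__(self, loader):
        self.loader = loader
        self.it = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.loader)  # restart from the beginning
            return next(self.it)

    next = __next__  # Python 2 compatibility


# Usage: wrap the loader before enumerating, e.g.
#   self.train_loader_iter = enumerate(InfiniteLoaderIter(self.train_loader))
# so an epoch of any length can draw batches without exhausting the iterator.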