import time

import torch


# Variant 1: single discriminator (seg_disc) over 2D image features;
# the detection discriminator is disabled.
def train(self, src_data_loader, tgt_data_loader):
    self.model.train()
    self.seg_disc.train()
    # self.det_disc.train()
    self.mode = 'train'
    self.data_loader = src_data_loader
    tgt_data_iter = iter(tgt_data_loader)
    assert len(tgt_data_loader) <= len(src_data_loader)
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, src_data_batch in enumerate(self.data_loader):
        # before train iter
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # fetch target data batch; restart the iterator when exhausted
        tgt_data_batch = next(tgt_data_iter, None)
        if tgt_data_batch is None:
            tgt_data_iter = iter(tgt_data_loader)
            tgt_data_batch = next(tgt_data_iter)

        # ------------------------
        # train Discriminators
        # ------------------------
        set_requires_grad(self.seg_disc, requires_grad=True)
        # src_img_feats: (4, 64, 225, 400)
        src_img_feats = self.model.extract_img_feat(**src_data_batch)
        tgt_img_feats = self.model.extract_img_feat(**tgt_data_batch)

        src_Dlogits = self.seg_disc(src_img_feats)
        src_Dloss = self.seg_disc.loss(src_Dlogits, src=True)
        log_src_Dloss = src_Dloss.item()

        tgt_Dlogits = self.seg_disc(tgt_img_feats)
        tgt_Dloss = self.seg_disc.loss(tgt_Dlogits, src=False)
        log_tgt_Dloss = tgt_Dloss.item()

        # NB: the features are not detached here, so this backward also
        # populates model gradients; they are zeroed before the network update.
        Dloss = (src_Dloss + tgt_Dloss) * 0.5
        self.seg_opt.zero_grad()
        Dloss.backward()
        self.seg_opt.step()

        # ------------------------
        # network forward on source: src_task_loss + lambda * src_GANLoss
        # ------------------------
        set_requires_grad(self.seg_disc, requires_grad=False)
        losses, src_img_feats = self.model(**src_data_batch)  # forward; losses: {'seg_loss': ...}
        src_Dlogits = self.seg_disc(src_img_feats)  # (N, 2, 225, 400)
        src_Dpred = src_Dlogits.max(1)[1]  # (N, 225, 400); cuda
        src_Dlabels = torch.ones_like(src_Dpred, dtype=torch.long).cuda()
        src_acc = (src_Dpred == src_Dlabels).float().mean()
        if src_acc > self.src_acc_threshold:
            losses['src_GANloss'] = self.lambda_GANLoss * self.seg_disc.loss(
                src_Dlogits, src=False)

        # ------------------------
        # network forward on target
        # ------------------------
        tgt_img_feats = self.model.extract_img_feat(**tgt_data_batch)
        tgt_Dlogits = self.seg_disc(tgt_img_feats)
        tgt_Dpred = tgt_Dlogits.max(1)[1]
        tgt_Dlabels = torch.zeros_like(tgt_Dpred, dtype=torch.long).cuda()
        tgt_acc = (tgt_Dpred == tgt_Dlabels).float().mean()
        if tgt_acc > self.tgt_acc_threshold:
            losses['tgt_GANloss'] = self.lambda_GANLoss * self.seg_disc.loss(
                tgt_Dlogits, src=True)

        loss, log_vars = parse_losses(losses)
        num_samples = len(src_data_batch['img_metas'])
        log_vars['src_Dloss'] = log_src_Dloss
        log_vars['tgt_Dloss'] = log_tgt_Dloss
        log_vars['src_acc'] = src_acc.item()
        log_vars['tgt_acc'] = tgt_acc.item()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # after_train_iter callback
        self.log_buffer.update(log_vars, num_samples)
        self.call_hook('after_train_iter')  # optimizer hook && logger hook
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
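# ---------------------------------------------------------------------
# Hedged sketch (not part of the original file): the train() variants
# in this file rely on a `set_requires_grad` helper and on
# discriminators exposing a `loss(logits, src=...)` method. A minimal
# version consistent with the call sites might look like this; the
# class name `FeatDiscriminator` and its architecture are assumptions.
# The conv layout fits the feature-map case, e.g. (N, 64, 225, 400)
# inputs; the variants that feed (N, 128) vectors would use an MLP
# analogue with the same `loss` contract.
# ---------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


def set_requires_grad(nets, requires_grad=False):
    """Enable/disable gradients for a module or a list of modules."""
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


class FeatDiscriminator(nn.Module):
    """Per-location binary domain classifier over (N, C, H, W) features.

    Class 1 = source, class 0 = target, matching the torch.ones_like /
    torch.zeros_like labels used in the train() variants (inferred from
    usage, not confirmed by the original code).
    """

    def __init__(self, in_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 2, 1),
        )

    def forward(self, x):
        return self.net(x)  # (N, 2, H, W)

    def loss(self, logits, src=True):
        # Build a constant domain label of shape (N, H, W), or (N,) for
        # vector inputs, and apply cross-entropy.
        label = logits.new_full(
            (logits.size(0), *logits.shape[2:]),
            1 if src else 0, dtype=torch.long)
        return F.cross_entropy(logits, label)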
# Variant 2: seg + det discriminators over fused features, with separate
# source/target discriminator updates and accuracy-gated GAN losses.
def train(self, src_data_loader, tgt_data_loader):
    self.model.train()
    self.seg_disc.train()
    self.det_disc.train()
    self.mode = 'train'
    self.data_loader = src_data_loader
    tgt_data_iter = iter(tgt_data_loader)
    assert len(tgt_data_loader) <= len(src_data_loader)
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, src_data_batch in enumerate(self.data_loader):
        # before train iter
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # fetch target data batch
        tgt_data_batch = next(tgt_data_iter, None)
        if tgt_data_batch is None:
            tgt_data_iter = iter(tgt_data_loader)
            tgt_data_batch = next(tgt_data_iter)

        # ------------------------
        # train Discriminators
        # ------------------------
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=True)
        # seg_src/tgt_feats: (N, 128); det_src/tgt_feats: (4, 128, 200, 400)
        seg_src_feats, det_src_feats = self.model.forward_fusion(**src_data_batch)
        # src segmentation
        seg_src_logits = self.seg_disc(seg_src_feats)
        seg_src_Dloss = self.seg_disc.loss(seg_src_logits, src=True)
        log_seg_src_Dloss = seg_src_Dloss.item()
        # src detection
        det_src_logits = self.det_disc(det_src_feats)
        det_src_Dloss = self.det_disc.loss(det_src_logits, src=True)
        log_det_src_Dloss = det_src_Dloss.item()
        src_Dloss = seg_src_Dloss + det_src_Dloss
        self.seg_opt.zero_grad()
        self.det_opt.zero_grad()
        src_Dloss.backward()
        self.seg_opt.step()
        self.det_opt.step()

        # tgt segmentation
        seg_tgt_feats, det_tgt_feats = self.model.forward_fusion(**tgt_data_batch)
        seg_tgt_logits = self.seg_disc(seg_tgt_feats)
        seg_tgt_Dloss = self.seg_disc.loss(seg_tgt_logits, src=False)
        log_seg_tgt_Dloss = seg_tgt_Dloss.item()
        # tgt detection
        det_tgt_logits = self.det_disc(det_tgt_feats)
        det_tgt_Dloss = self.det_disc.loss(det_tgt_logits, src=False)
        log_det_tgt_Dloss = det_tgt_Dloss.item()
        tgt_Dloss = seg_tgt_Dloss + det_tgt_Dloss
        self.seg_opt.zero_grad()
        self.det_opt.zero_grad()
        tgt_Dloss.backward()
        self.seg_opt.step()
        self.det_opt.step()

        # ------------------------
        # train network on source: task loss + GANLoss
        # ------------------------
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=False)
        losses, seg_src_feats, det_src_feats = self.model(**src_data_batch)  # forward; losses: {'seg_loss': ...}
        seg_disc_logits = self.seg_disc(seg_src_feats)  # (N, 2)
        seg_disc_pred = seg_disc_logits.max(1)[1]  # (N,); cuda
        seg_label = torch.ones_like(seg_disc_pred, dtype=torch.long).cuda()
        seg_acc = (seg_disc_pred == seg_label).float().mean()
        acc_threshold = 0.6
        if seg_acc > acc_threshold:
            losses['seg_src_GANloss'] = self.lambda_GANLoss * self.seg_disc.loss(
                seg_disc_logits, src=False)

        det_disc_logits = self.det_disc(det_src_feats)  # (4, 2, 49, 99)
        det_disc_pred = det_disc_logits.max(1)[1]  # (4, 49, 99); cuda
        det_label = torch.ones_like(det_disc_pred, dtype=torch.long).cuda()
        det_acc = (det_disc_pred == det_label).float().mean()
        if det_acc > acc_threshold:
            losses['det_src_GANloss'] = self.lambda_GANLoss * self.det_disc.loss(
                det_disc_logits, src=False)

        loss, log_vars = parse_losses(losses)
        num_samples = len(src_data_batch['img_metas'])
        log_vars['seg_src_Dloss'] = log_seg_src_Dloss
        log_vars['det_src_Dloss'] = log_det_src_Dloss
        log_vars['seg_tgt_Dloss'] = log_seg_tgt_Dloss
        log_vars['det_tgt_Dloss'] = log_det_tgt_Dloss
        log_vars['seg_src_acc'] = seg_acc.item()
        log_vars['det_src_acc'] = det_acc.item()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------------
        # train network on target: only GANLoss
        # ------------------------
        self.optimizer.zero_grad()
        seg_tgt_feats, det_tgt_feats = self.model.forward_fusion(**tgt_data_batch)
        seg_disc_logits = self.seg_disc(seg_tgt_feats)  # (N, 2)
        seg_disc_pred = seg_disc_logits.max(1)[1]  # (N,); cuda
        seg_label = torch.zeros_like(seg_disc_pred, dtype=torch.long).cuda()
        seg_acc = (seg_disc_pred == seg_label).float().mean()
        tgt_GANloss = None
        if seg_acc > acc_threshold:
            seg_tgt_loss = self.lambda_GANLoss * self.seg_disc.loss(
                seg_disc_logits, src=True)
            tgt_GANloss = seg_tgt_loss
            log_vars['seg_tgt_GANloss'] = seg_tgt_loss.item()

        det_disc_logits = self.det_disc(det_tgt_feats)  # (4, 2, 49, 99)
        det_disc_pred = det_disc_logits.max(1)[1]  # (4, 49, 99); cuda
        det_label = torch.zeros_like(det_disc_pred, dtype=torch.long).cuda()
        det_acc = (det_disc_pred == det_label).float().mean()
        if det_acc > acc_threshold:
            det_tgt_loss = self.lambda_GANLoss * self.det_disc.loss(
                det_disc_logits, src=True)
            log_vars['det_tgt_GANloss'] = det_tgt_loss.item()
            if tgt_GANloss is None:
                tgt_GANloss = det_tgt_loss
            else:
                tgt_GANloss += det_tgt_loss

        log_vars['seg_tgt_acc'] = seg_acc.item()
        log_vars['det_tgt_acc'] = det_acc.item()
        if tgt_GANloss is not None:
            tgt_GANloss.backward()
            self.optimizer.step()

        # after_train_iter callback
        self.log_buffer.update(log_vars, num_samples)
        self.call_hook('after_train_iter')  # optimizer hook && logger hook
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
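# ---------------------------------------------------------------------
# Hedged sketch (assumption, not the original implementation):
# `parse_losses` is used above the way mmdetection's `_parse_losses`
# works: reduce each entry of the loss dict, sum every key containing
# 'loss' into one scalar for backward, and return float log_vars.
# ---------------------------------------------------------------------
from collections import OrderedDict


def parse_losses(losses):
    """Reduce a loss dict to (total_loss, log_vars)."""
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, (list, tuple)):
            log_vars[name] = sum(v.mean() for v in value)
        else:
            raise TypeError(f'{name} is not a tensor or a list of tensors')
    # Total for backward: every entry whose key contains 'loss'
    # ('seg_loss', 'src_GANloss', ...).
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    # Detach to plain floats for logging.
    log_vars = OrderedDict((k, v.item()) for k, v in log_vars.items())
    return loss, log_vars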
# Variant 3: discriminator is optional (self.with_disc); the discriminator
# input is selected via get_feats().
def train(self, src_data_loader, tgt_data_loader):
    self.model.train()
    self.mode = 'train'
    self.data_loader = src_data_loader
    if self.with_disc:
        self.seg_disc.train()
        tgt_data_iter = iter(tgt_data_loader)
        assert len(tgt_data_loader) <= len(src_data_loader)
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, src_data_batch in enumerate(self.data_loader):
        # before train iter
        self._inner_iter = i
        self.call_hook('before_train_iter')
        if self.with_disc:
            # fetch target data batch
            tgt_data_batch = next(tgt_data_iter, None)
            if tgt_data_batch is None:
                tgt_data_iter = iter(tgt_data_loader)
                tgt_data_batch = next(tgt_data_iter)

        # ------------------------
        # train Discriminators
        # ------------------------
        if self.with_disc:
            set_requires_grad(self.seg_disc, requires_grad=True)
            # seg_feats_before_fusion: (N, 64, 225, 400); seg_feats_after_fusion: (N, 128)
            # src segmentation
            img_feats, fusion_feats = self.model.extract_feat(**src_data_batch)
            src_feats = self.get_feats(img_feats, fusion_feats)
            src_logits = self.seg_disc(src_feats)
            src_Dloss = self.seg_disc.loss(src_logits, src=True)
            log_src_Dloss = src_Dloss.item()
            self.seg_opt.zero_grad()
            src_Dloss.backward()
            self.seg_opt.step()
            # tgt segmentation
            img_feats, fusion_feats = self.model.extract_feat(**tgt_data_batch)
            tgt_feats = self.get_feats(img_feats, fusion_feats)
            tgt_logits = self.seg_disc(tgt_feats)
            tgt_Dloss = self.seg_disc.loss(tgt_logits, src=False)
            log_tgt_Dloss = tgt_Dloss.item()
            self.seg_opt.zero_grad()
            tgt_Dloss.backward()
            self.seg_opt.step()

        # ------------------------
        # train network on source: task loss + GANLoss
        # ------------------------
        losses, img_feats, fusion_feats = self.model(**src_data_batch)  # forward; losses: {'seg_loss': ...}
        src_feats = self.get_feats(img_feats, fusion_feats)
        acc_threshold = 0.6
        if self.with_disc:
            # freeze the discriminator while updating the network
            set_requires_grad(self.seg_disc, requires_grad=False)
            disc_logits = self.seg_disc(src_feats)  # (N, 2)
            disc_pred = disc_logits.max(1)[1]  # (N,); cuda
            seg_label = torch.ones_like(disc_pred, dtype=torch.long).cuda()
            src_acc = (disc_pred == seg_label).float().mean()
            if src_acc > acc_threshold:
                losses['src_GANloss'] = self.lambda_GANLoss * self.seg_disc.loss(
                    disc_logits, src=False)

        loss, log_vars = parse_losses(losses)
        num_samples = len(src_data_batch['img_metas'])
        if self.with_disc:
            log_vars['src_Dloss'] = log_src_Dloss
            log_vars['tgt_Dloss'] = log_tgt_Dloss
            log_vars['src_acc'] = src_acc.item()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------------
        # train network on target: only GANLoss
        # ------------------------
        if self.with_disc:
            img_feats, fusion_feats = self.model.extract_feat(**tgt_data_batch)
            tgt_feats = self.get_feats(img_feats, fusion_feats)
            disc_logits = self.seg_disc(tgt_feats)  # (N, 2)
            disc_pred = disc_logits.max(1)[1]  # (N,); cuda
            seg_label = torch.zeros_like(disc_pred, dtype=torch.long).cuda()
            tgt_acc = (disc_pred == seg_label).float().mean()
            log_vars['tgt_acc'] = tgt_acc.item()
            if tgt_acc > acc_threshold:
                tgt_GANloss = self.lambda_GANLoss * self.seg_disc.loss(
                    disc_logits, src=True)
                log_vars['tgt_GANloss'] = tgt_GANloss.item()  # original tgt_loss
                self.optimizer.zero_grad()
                tgt_GANloss.backward()
                self.optimizer.step()

        # after_train_iter callback
        self.log_buffer.update(log_vars, num_samples)
        self.call_hook('after_train_iter')  # optimizer hook && logger hook
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
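# ---------------------------------------------------------------------
# Hedged sketch: `get_feats` is not defined in this file. From its call
# sites it selects which representation the discriminator sees, e.g.
# the pre-fusion feature map or the post-fusion vector. The `disc_feat`
# attribute below is an assumption for illustration only.
# ---------------------------------------------------------------------
def get_feats(self, feats_before_fusion, feats_after_fusion):
    """Pick the discriminator input (assumed config switch)."""
    if self.disc_feat == 'before_fusion':
        return feats_before_fusion
    return feats_after_fusion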
# Variant 4: seg + det discriminators trained on detached features,
# one source pass and one target pass per iteration.
def train(self, src_data_loader, tgt_data_loader):
    self.model.train()
    self.seg_disc.train()
    self.det_disc.train()
    self.mode = 'train'
    self.data_loader = src_data_loader
    tgt_data_iter = iter(tgt_data_loader)
    assert len(tgt_data_loader) <= len(src_data_loader)
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, src_data_batch in enumerate(self.data_loader):
        # before train iter
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # fetch target data batch
        tgt_data_batch = next(tgt_data_iter, None)
        if tgt_data_batch is None:
            tgt_data_iter = iter(tgt_data_loader)
            tgt_data_batch = next(tgt_data_iter)

        # ------------------------
        # train on src
        # ------------------------
        # train Discriminators
        # if acc(disc) > max: don't train disc
        losses, seg_fusion_feats, det_fusion_feats = self.model(**src_data_batch)  # forward; losses: {'seg_loss': ...}
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=True)
        seg_disc_logits = self.seg_disc(seg_fusion_feats.detach())
        det_disc_logits = self.det_disc(det_fusion_feats.detach())
        seg_disc_loss = self.seg_disc.loss(seg_disc_logits, src=True)
        det_disc_loss = self.det_disc.loss(det_disc_logits, src=True)
        disc_loss = seg_disc_loss + det_disc_loss
        self.seg_opt.zero_grad()
        self.det_opt.zero_grad()
        disc_loss.backward()
        self.seg_opt.step()
        self.det_opt.step()

        # train network
        # if acc(disc) > min: add GAN Loss
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=False)
        seg_disc_logits = self.seg_disc(seg_fusion_feats)  # (N, 2)
        seg_disc_pred = seg_disc_logits.max(1)[1]  # (N,); cuda
        seg_label = torch.ones(len(seg_disc_pred), dtype=torch.long).cuda()
        seg_acc = (seg_disc_pred == seg_label).float().mean()
        if seg_acc > 0.6:
            losses['seg_src_loss'] = self.seg_disc.loss(seg_disc_logits, src=False)

        det_disc_logits = self.det_disc(det_fusion_feats)  # (M, 2)
        det_disc_pred = det_disc_logits.max(1)[1]  # (M,); cuda
        det_label = torch.ones(len(det_disc_pred), dtype=torch.long).cuda()
        det_acc = (det_disc_pred == det_label).float().mean()
        if det_acc > 0.6:
            losses['det_src_loss'] = self.det_disc.loss(det_disc_logits, src=False)

        loss, log_vars = parse_losses(losses)
        num_samples = len(src_data_batch['img_metas'])
        log_vars['seg_src_acc'] = seg_acc.item()
        log_vars['det_src_acc'] = det_acc.item()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------------
        # train on tgt
        # ------------------------
        # train Discriminators
        seg_fusion_feats, det_fusion_feats = self.model.forward_fusion(**tgt_data_batch)
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=True)
        seg_disc_logits = self.seg_disc(seg_fusion_feats.detach())
        det_disc_logits = self.det_disc(det_fusion_feats.detach())
        seg_disc_loss = self.seg_disc.loss(seg_disc_logits, src=False)
        det_disc_loss = self.det_disc.loss(det_disc_logits, src=False)
        disc_loss = seg_disc_loss + det_disc_loss
        self.seg_opt.zero_grad()
        self.det_opt.zero_grad()
        disc_loss.backward()
        self.seg_opt.step()
        self.det_opt.step()

        # train network on target domain without task loss
        self.optimizer.zero_grad()
        set_requires_grad([self.seg_disc, self.det_disc], requires_grad=False)
        seg_disc_logits = self.seg_disc(seg_fusion_feats)  # (N, 2)
        seg_disc_pred = seg_disc_logits.max(1)[1]  # (N,); cuda
        seg_label = torch.zeros(len(seg_disc_pred), dtype=torch.long).cuda()
        seg_acc = (seg_disc_pred == seg_label).float().mean()
        if seg_acc > 0.6:
            seg_tgt_loss = self.seg_disc.loss(seg_disc_logits, src=True)
            # retain the graph: the det branch below backprops through
            # layers shared with the seg branch
            seg_tgt_loss.backward(retain_graph=True)

        det_disc_logits = self.det_disc(det_fusion_feats)  # (M, 2)
        det_disc_pred = det_disc_logits.max(1)[1]  # (M,); cuda
        det_label = torch.zeros(len(det_disc_pred), dtype=torch.long).cuda()
        det_acc = (det_disc_pred == det_label).float().mean()
        if det_acc > 0.6:
            det_tgt_loss = self.det_disc.loss(det_disc_logits, src=True)
            det_tgt_loss.backward()  # accumulate grad

        log_vars['seg_tgt_acc'] = seg_acc.item()
        log_vars['det_tgt_acc'] = det_acc.item()
        self.optimizer.step()

        # after_train_iter callback
        self.log_buffer.update(log_vars, num_samples)
        self.call_hook('after_train_iter')  # optimizer hook && logger hook
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
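# ---------------------------------------------------------------------
# Why the first target-side backward above retains the graph: when the
# seg and det branches share backbone layers, two sequential .backward()
# calls traverse shared nodes, and PyTorch frees their buffers after the
# first call unless retain_graph=True. Standalone toy illustration
# (names are illustrative, not from the original code):
# ---------------------------------------------------------------------
def _demo_shared_graph_backward():
    shared = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)
    h = shared(x)
    loss_a = h.sum()
    loss_b = (h ** 2).sum()
    loss_a.backward(retain_graph=True)  # keep buffers for the second pass
    loss_b.backward()                   # gradients accumulate in `shared`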
# Variant 5: per-modality discriminators; img_disc on image features,
# lidar_disc on lidar/fusion features.
def train(self, src_data_loader, tgt_data_loader):
    self.model.train()
    self.img_disc.train()
    self.lidar_disc.train()
    self.mode = 'train'
    self.data_loader = src_data_loader
    tgt_data_iter = iter(tgt_data_loader)
    assert len(tgt_data_loader) <= len(src_data_loader)
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, src_data_batch in enumerate(self.data_loader):
        # before train iter
        self._inner_iter = i
        self.call_hook('before_train_iter')
        # fetch target data batch
        tgt_data_batch = next(tgt_data_iter, None)
        if tgt_data_batch is None:
            tgt_data_iter = iter(tgt_data_loader)
            tgt_data_batch = next(tgt_data_iter)

        # ------------------------
        # train Discriminators
        # ------------------------
        set_requires_grad([self.img_disc, self.lidar_disc], requires_grad=True)
        # img_feats: (N, 64, 225, 400); lidar_feats: (N, 64); fusion_feats: (N, 128)
        # src
        src_img_feats, src_lidar_feats, src_fusion_feats = self.model.extract_feat(
            **src_data_batch)
        src_feats0 = src_img_feats
        src_feats1 = self.get_feats(src_lidar_feats, src_fusion_feats)
        src_logits0 = self.img_disc(src_feats0.detach())
        src_Dloss0 = self.img_disc.loss(src_logits0, src=True)
        log_src_Dloss0 = src_Dloss0.item()
        src_logits1 = self.lidar_disc(src_feats1.detach())
        src_Dloss1 = self.lidar_disc.loss(src_logits1, src=True)
        log_src_Dloss1 = src_Dloss1.item()
        # tgt
        tgt_img_feats, tgt_lidar_feats, tgt_fusion_feats = self.model.extract_feat(
            **tgt_data_batch)
        tgt_feats0 = tgt_img_feats
        tgt_feats1 = self.get_feats(tgt_lidar_feats, tgt_fusion_feats)
        tgt_logits0 = self.img_disc(tgt_feats0.detach())
        tgt_Dloss0 = self.img_disc.loss(tgt_logits0, src=False)
        log_tgt_Dloss0 = tgt_Dloss0.item()
        tgt_logits1 = self.lidar_disc(tgt_feats1.detach())
        tgt_Dloss1 = self.lidar_disc.loss(tgt_logits1, src=False)
        log_tgt_Dloss1 = tgt_Dloss1.item()
        # backward
        img_Dloss = (src_Dloss0 + tgt_Dloss0) * 0.5
        lidar_Dloss = (src_Dloss1 + tgt_Dloss1) * 0.5
        self.img_opt.zero_grad()
        img_Dloss.backward()
        self.img_opt.step()
        self.lidar_opt.zero_grad()
        lidar_Dloss.backward()
        self.lidar_opt.step()

        # ------------------------
        # network forward on source: task loss + self.lambda_GANLoss * GANLoss
        # ------------------------
        losses, src_img_feats, src_lidar_feats, src_fusion_feats = self.model(
            **src_data_batch)
        src_feats0 = src_img_feats
        src_feats1 = self.get_feats(src_lidar_feats, src_fusion_feats)
        set_requires_grad([self.img_disc, self.lidar_disc], requires_grad=False)
        src_logits0 = self.img_disc(src_feats0)
        src_pred0 = src_logits0.max(1)[1]
        src_label0 = torch.ones_like(src_pred0, dtype=torch.long).cuda()
        src_acc0 = (src_pred0 == src_label0).float().mean()
        if src_acc0 > self.src_acc_threshold:
            losses['src_img_GANloss'] = self.lambda_img * self.img_disc.loss(
                src_logits0, src=False)
        src_logits1 = self.lidar_disc(src_feats1)
        src_pred1 = src_logits1.max(1)[1]
        src_label1 = torch.ones_like(src_pred1, dtype=torch.long).cuda()
        src_acc1 = (src_pred1 == src_label1).float().mean()
        if src_acc1 > self.src_acc_threshold:
            losses['src_lidar_GANloss'] = self.lambda_lidar * self.lidar_disc.loss(
                src_logits1, src=False)

        # ------------------------
        # network forward on target: only GANLoss
        # ------------------------
        tgt_img_feats, tgt_lidar_feats, tgt_fusion_feats = self.model.extract_feat(
            **tgt_data_batch)
        tgt_feats0 = tgt_img_feats
        tgt_feats1 = self.get_feats(tgt_lidar_feats, tgt_fusion_feats)
        tgt_logits0 = self.img_disc(tgt_feats0)
        tgt_pred0 = tgt_logits0.max(1)[1]
        tgt_label0 = torch.zeros_like(tgt_pred0, dtype=torch.long).cuda()
        tgt_acc0 = (tgt_pred0 == tgt_label0).float().mean()
        if tgt_acc0 > self.tgt_acc_threshold:
            losses['tgt_img_GANloss'] = self.lambda_img * self.img_disc.loss(
                tgt_logits0, src=True)
        tgt_logits1 = self.lidar_disc(tgt_feats1)
        tgt_pred1 = tgt_logits1.max(1)[1]
        tgt_label1 = torch.zeros_like(tgt_pred1, dtype=torch.long).cuda()
        tgt_acc1 = (tgt_pred1 == tgt_label1).float().mean()
        if tgt_acc1 > self.tgt_acc_threshold:
            losses['tgt_lidar_GANloss'] = self.lambda_lidar * self.lidar_disc.loss(
                tgt_logits1, src=True)

        loss, log_vars = parse_losses(losses)
        num_samples = len(src_data_batch['img_metas'])
        log_vars['src_img_Dloss'] = log_src_Dloss0
        log_vars['src_lidar_Dloss'] = log_src_Dloss1
        log_vars['tgt_img_Dloss'] = log_tgt_Dloss0
        log_vars['tgt_lidar_Dloss'] = log_tgt_Dloss1
        log_vars['src_img_acc'] = src_acc0.item()
        log_vars['src_lidar_acc'] = src_acc1.item()
        log_vars['tgt_img_acc'] = tgt_acc0.item()
        log_vars['tgt_lidar_acc'] = tgt_acc1.item()

        # ------------------------
        # network backward: src_task_loss + self.lambda_GANLoss * tgt_GANloss
        # ------------------------
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # after_train_iter callback
        self.log_buffer.update(log_vars, num_samples)
        self.call_hook('after_train_iter')  # optimizer hook && logger hook
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
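# ---------------------------------------------------------------------
# Hedged usage sketch (assumption): each runner variant expects the main
# model optimizer (self.optimizer) plus one optimizer per discriminator
# (self.seg_opt, self.det_opt, self.img_opt, self.lidar_opt). A typical
# construction; the optimizer types and hyperparameters below are
# illustrative, not from the original code.
# ---------------------------------------------------------------------
def build_optimizers(model, discs, lr=1e-3, disc_lr=1e-4):
    """Return the model optimizer and one Adam optimizer per discriminator."""
    model_opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    disc_opts = [
        torch.optim.Adam(d.parameters(), lr=disc_lr, betas=(0.5, 0.999))
        for d in discs
    ]
    return model_opt, disc_opts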