def train_generator(model_G, model_D, trainloader, optimizer_G, train_dataset_size, device):
    """
    Train the generator (segmentation network) when GT is available: loss_ce and loss_adv.
    :return: loss_ce, loss_adv
    """
    loss_ce_value = []
    loss_adv_value = []
    NUM_BATCHES = int(np.floor(train_dataset_size / BATCH_SIZE))

    for i, mini_batch in tqdm.tqdm(enumerate(trainloader), total=NUM_BATCHES):
        # train G only: keep grads in G, don't accumulate grads in D
        for param in model_G.parameters():
            param.requires_grad = True
        for param in model_D.parameters():
            param.requires_grad = False

        optimizer_G.zero_grad()

        points, cls_gt, seg_gt = mini_batch
        points, cls_gt, seg_gt = Variable(points).float(), \
                                 Variable(cls_gt).float(), \
                                 Variable(seg_gt).type(torch.LongTensor)
        points, cls_gt, seg_gt = points.to(device), cls_gt.to(device), seg_gt.to(device)

        pred = model_G(points, cls_gt)

        # cross-entropy loss on the labelled points
        loss_ce = loss_calc(pred, seg_gt, device, mask=False)

        # adversarial loss: push D towards labelling the prediction as GT
        D_out = model_D(F.softmax(pred, dim=2))
        ignore_mask = np.zeros(seg_gt.shape).astype(bool)
        loss_adv = loss_bce(D_out, make_D_label(GT_LABEL, ignore_mask, device), device)

        loss_seg = loss_ce + LAMBDA_ADV * loss_adv
        loss_seg.backward()
        optimizer_G.step()

        loss_ce_value.append(loss_ce.item())
        loss_adv_value.append(loss_adv.item())

    return np.average(loss_ce_value), np.average(loss_adv_value)
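
# ---------------------------------------------------------------------------------------
# NOTE: make_D_label and loss_bce are imported from elsewhere in this repo and are not
# shown here. The sketches below (hypothetical sketch_* names, assumed BxN shapes and an
# assumed 255 ignore sentinel) only illustrate the usual pattern: build a constant label
# map matching the discriminator output and compare the two with a binary cross-entropy.
# They rely on this module's existing torch / F / np imports.
def sketch_make_D_label(label, ignore_mask, device):
    """Build a BxN map filled with `label` (e.g. 0 = prediction, 1 = GT); ignored points
    get a sentinel value the BCE loss can skip."""
    D_label = torch.full(ignore_mask.shape, float(label), device=device)
    D_label[torch.from_numpy(ignore_mask).to(device)] = 255.0  # assumed ignore sentinel
    return D_label


def sketch_loss_bce(D_out, D_label, device):
    """Binary cross-entropy between discriminator logits and the constant label map,
    skipping the sentinel entries."""
    D_out = D_out.view_as(D_label)
    valid = D_label != 255.0
    return F.binary_cross_entropy_with_logits(D_out[valid], D_label[valid])
# ---------------------------------------------------------------------------------------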
def train(self, src_loader, tar_loader, val_loader):
    loss_rot = loss_adv = loss_weight = loss_D_s = loss_D_t = 0
    args = self.args
    log = self.logger
    device = self.device

    interp_source = nn.Upsample(size=(args.datasets.source.images_size[1],
                                      args.datasets.source.images_size[0]),
                                mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(args.datasets.target.images_size[1],
                                      args.datasets.target.images_size[0]),
                                mode='bilinear', align_corners=True)
    interp_prediction = nn.Upsample(size=(args.auxiliary.images_size[1],
                                          args.auxiliary.images_size[0]),
                                    mode='bilinear', align_corners=True)

    source_iter = enumerate(src_loader)
    target_iter = enumerate(tar_loader)

    self.model.train()
    self.model = self.model.to(device)
    if args.method.adversarial:
        self.model_D.train()
        self.model_D = self.model_D.to(device)
    if args.method.self:
        self.model_A.train()
        self.model_A = self.model_A.to(device)

    log.info('########### TRAINING STARTED ############')
    start = time.time()

    for i_iter in range(self.start_iter, self.num_steps):
        self.model.train()
        self.optimizer.zero_grad()
        adjust_learning_rate(self.optimizer, self.preheat, args.num_steps,
                             args.power, i_iter, args.model.optimizer)
        if args.method.adversarial:
            self.model_D.train()
            self.optimizer_D.zero_grad()
            adjust_learning_rate(self.optimizer_D, self.preheat, args.num_steps,
                                 args.power, i_iter, args.discriminator.optimizer)
        if args.method.self:
            self.model_A.train()
            self.optimizer_A.zero_grad()
            adjust_learning_rate(self.optimizer_A, self.preheat, args.num_steps,
                                 args.power, i_iter, args.auxiliary.optimizer)

        damping = (1 - i_iter / self.num_steps)  # similar to early stopping

        # ==================================================================================
        # train G
        # ==================================================================================
        if args.method.adversarial:
            for param in self.model_D.parameters():  # remove grads in D
                param.requires_grad = False

        # Train with Source
        _, batch = next(source_iter)
        images_s, labels_s, _, _ = batch
        images_s = images_s.to(device)
        pred_source1_, pred_source2_ = self.model(images_s)
        pred_source1 = interp_source(pred_source1_)
        pred_source2 = interp_source(pred_source2_)

        # Segmentation loss
        loss_seg = (loss_calc(self.num_classes, pred_source1, labels_s, device) +
                    loss_calc(self.num_classes, pred_source2, labels_s, device))
        loss_seg.backward()
        self.losses['seg'].append(loss_seg.item())

        # Train with Target
        _, batch = next(target_iter)
        images_t, labels_t = batch
        images_t = images_t.to(device)
        pred_target1_, pred_target2_ = self.model(images_t)
        pred_target1 = interp_target(pred_target1_)
        pred_target2 = interp_target(pred_target2_)

        # Semi-supervised approach: use a fraction of the target labels
        if args.use_target_labels and i_iter % int(1 / args.target_frac) == 0:
            loss_seg_t = (loss_calc(args.num_classes, pred_target1, labels_t, device) +
                          loss_calc(args.num_classes, pred_target2, labels_t, device))
            loss_seg_t.backward()
            self.losses['seg_t'].append(loss_seg_t.item())

        # Adversarial Loss
        if args.method.adversarial:
            # keep the target predictions attached to the graph here, so the adversarial
            # gradient actually reaches the segmentation network (they are detached again
            # below when training D)
            weight_map = weightmap(F.softmax(pred_target1, dim=1),
                                   F.softmax(pred_target2, dim=1))
            D_out = interp_target(self.model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

            # Adaptive Adversarial Loss
            if i_iter > self.preheat:
                loss_adv = self.weighted_bce_loss(
                    D_out,
                    torch.FloatTensor(D_out.data.size()).fill_(self.source_label).to(device),
                    weight_map, args.Epsilon, args.Lambda_local)
            else:
                loss_adv = self.bce_loss(
                    D_out,
                    torch.FloatTensor(D_out.data.size()).fill_(self.source_label).to(device))

            loss_adv = loss_adv * self.args.Lambda_adv * damping
            loss_adv.backward()
            self.losses['adv'].append(loss_adv.item())

        # Weight Discrepancy Loss
        if args.weight_loss:
            W5 = None
            W6 = None
            if args.model.name == 'DeepLab':  # TODO: add ERF-Net
                for (w5, w6) in zip(self.model.layer5.parameters(),
                                    self.model.layer6.parameters()):
                    if W5 is None and W6 is None:
                        W5 = w5.view(-1)
                        W6 = w6.view(-1)
                    else:
                        W5 = torch.cat((W5, w5.view(-1)), 0)
                        W6 = torch.cat((W6, w6.view(-1)), 0)

            # cosine distance between the two classifier weight vectors
            loss_weight = (torch.matmul(W5, W6) /
                           (torch.norm(W5) * torch.norm(W6)) + 1)  # +1 keeps the loss positive
            loss_weight = loss_weight * args.Lambda_weight * damping * 2
            loss_weight.backward()
            self.losses['weight'].append(loss_weight.item())

        # ==================================================================================
        # train D
        # ==================================================================================
        if args.method.adversarial:
            # bring back grads in D
            for param in self.model_D.parameters():
                param.requires_grad = True

            # Train with Source
            pred_source1 = pred_source1.detach()
            pred_source2 = pred_source2.detach()
            D_out_s = interp_source(self.model_D(F.softmax(pred_source1 + pred_source2, dim=1)))
            loss_D_s = self.bce_loss(
                D_out_s,
                torch.FloatTensor(D_out_s.data.size()).fill_(self.source_label).to(device))
            loss_D_s.backward()
            self.losses['ds'].append(loss_D_s.item())

            # Train with Target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            weight_map = weight_map.detach()
            D_out_t = interp_target(self.model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

            # Adaptive Adversarial Loss
            if i_iter > self.preheat:
                loss_D_t = self.weighted_bce_loss(
                    D_out_t,
                    torch.FloatTensor(D_out_t.data.size()).fill_(self.target_label).to(device),
                    weight_map, args.Epsilon, args.Lambda_local)
            else:
                loss_D_t = self.bce_loss(
                    D_out_t,
                    torch.FloatTensor(D_out_t.data.size()).fill_(self.target_label).to(device))
            loss_D_t.backward()
            self.losses['dt'].append(loss_D_t.item())

        # ==================================================================================
        # train SELF-SUPERVISED TASK
        # ==================================================================================
        if args.method.self:
            '''
            SELF-SUPERVISED (ROTATION) ALGORITHM
            - Get the squared prediction
            - Rotate it randomly (0, 90, 180, 270) -> assign self-label (0, 1, 2, 3)
              [*2 if we also want to classify source/target]
            - Send the rotated prediction to the classifier
            - Get the loss
            - Update the weights of the classifier and of G (segmentation network)
            '''
            # Source predictions (detached from the segmentation graph)
            pred_source1 = pred_source1_.detach()
            pred_source2 = pred_source2_.detach()
            # Target predictions (detached from the segmentation graph)
            pred_target1 = pred_target1_.detach()
            pred_target2 = pred_target2_.detach()

            pred_source = interp_prediction(F.softmax(pred_source1 + pred_source2, dim=1))
            pred_target = interp_prediction(F.softmax(pred_target1 + pred_target2, dim=1))

            # ROTATE TENSORS
            # source
            label_source = torch.empty(1, dtype=torch.long).random_(
                args.auxiliary.aux_classes // 2).to(device)
            rotated_pred_source = rotate_tensor(pred_source, self.rotations[label_source.item()])
            pred_source_label = self.model_A(rotated_pred_source)
            loss_rot_source = self.aux_loss(pred_source_label, label_source)

            # target
            label_target = torch.empty(1, dtype=torch.long).random_(
                args.auxiliary.aux_classes).to(device)
            rotated_pred_target = rotate_tensor(pred_target, self.rotations[label_target.item()])
            pred_target_label = self.model_A(rotated_pred_target)
            loss_rot_target = self.aux_loss(pred_target_label, label_target)

            loss_rot = (loss_rot_source + loss_rot_target) * args.Lambda_aux
            # loss_rot = loss_rot_target * args.Lambda_aux
            loss_rot.backward()
            self.losses['aux'].append(loss_rot.item())

        # Optimizer steps
        self.optimizer.step()
        if args.method.adversarial:
            self.optimizer_D.step()
        if args.method.self:
            self.optimizer_A.step()

        if i_iter % 10 == 0:
            log.info(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f}, loss_rot = {3:.4f}, '
                'loss_adv = {4:.4f}, loss_weight = {5:.4f}, '
                'loss_D_s = {6:.4f}, loss_D_t = {7:.4f}'.format(
                    i_iter, self.num_steps, loss_seg, loss_rot, loss_adv,
                    loss_weight, loss_D_s, loss_D_t))

        if (i_iter % args.save_pred_every == 0 and i_iter != 0) or i_iter == self.num_steps - 1:
            log.info('saving weights...')
            i_iter = i_iter if i_iter != self.num_steps - 1 else i_iter + 1  # for last iter
            torch.save(self.model.state_dict(),
                       join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            if args.method.adversarial:
                torch.save(self.model_D.state_dict(),
                           join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            if args.method.self:
                torch.save(self.model_A.state_dict(),
                           join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_Aux.pth'))

            self.validate(i_iter, val_loader)
            compute_mIoU(i_iter, args.datasets.target.val.label_dir, self.save_path,
                         args.datasets.target.json_file, args.datasets.target.base_list,
                         args.results_dir)
            save_losses_plot(args.results_dir, self.losses)
            # also save images of source and target
            save_segmentations(args.images_dir, images_s, labels_s, pred_source1, images_t)

        del images_s, labels_s, pred_source1, pred_source2, pred_source1_, pred_source2_
        del images_t, labels_t, pred_target1, pred_target2, pred_target1_, pred_target2_

    end = time.time()
    days = int((end - start) / 86400)
    log.info('Total training time: {} days, {} hours, {} min, {} sec'.format(
        days,
        int((end - start) / 3600) - (days * 24),
        int((end - start) / 60 % 60),
        int((end - start) % 60)))
    print('### Experiment: ' + args.experiment + ' finished ###')
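
# ---------------------------------------------------------------------------------------
# NOTE: rotate_tensor, self.rotations and weightmap are defined elsewhere in this repo.
# The sketches below only illustrate the behaviour assumed by the training loop above;
# the sketch_* names, the "k = number of 90-degree turns" convention and the epsilon
# term are assumptions, not the repo's actual API.
def sketch_rotate_tensor(x, k):
    """Rotate a BxCxHxW tensor by k * 90 degrees on its spatial dimensions."""
    return torch.rot90(x, k, dims=(2, 3))


def sketch_weightmap(prob1, prob2):
    """Per-pixel discrepancy between the two classifiers' softmax outputs
    (1 - cosine similarity along the class dimension)."""
    cos = torch.sum(prob1 * prob2, dim=1, keepdim=True) / (
        torch.norm(prob1, 2, dim=1, keepdim=True) *
        torch.norm(prob2, 2, dim=1, keepdim=True) + 1e-8)
    return 1.0 - cos
# In the rotation task a random label in {0, ..., aux_classes - 1} picks the rotation,
# and the auxiliary classifier model_A is trained to recover that label from the
# rotated prediction.
# ---------------------------------------------------------------------------------------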
def train_semi(model_G, model_D, trainloader_remain, optimizer_G, train_dataset_size, device):
    """
    Train the generator when GT is NOT available: loss_semi_adv and loss_semi.
    :return: loss_semi_value = loss_semi_adv + loss_semi
    """
    loss_semi_adv_value = []
    loss_semi_value = []
    NUM_BATCHES = int(np.floor(train_dataset_size / BATCH_SIZE))

    for i, mini_batch in tqdm.tqdm(enumerate(trainloader_remain), total=NUM_BATCHES):
        # train G only: keep grads in G, don't accumulate grads in D
        for param in model_G.parameters():
            param.requires_grad = True
        for param in model_D.parameters():
            param.requires_grad = False

        optimizer_G.zero_grad()

        # only the points are available (no segmentation GT)
        points, cls, _ = mini_batch
        points, cls = Variable(points).float(), Variable(cls).float()
        points, cls = points.to(device), cls.to(device)

        pred = model_G(points, cls)
        pred_remain = pred.detach()

        D_out = model_D(F.softmax(pred, dim=2))
        D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy()  # BxN
        ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)  # Bx2048

        ### semi_adv ###
        loss_semi_adv = LAMBDA_SEMI_ADV * loss_bce(
            D_out, make_D_label(GT_LABEL, ignore_mask_remain, device), device)

        ### semi ###
        semi_ignore_mask = (D_out_sigmoid < MASK_T)
        semi_gt = pred.data.cpu().numpy().argmax(axis=2)
        semi_gt[semi_ignore_mask] = 999
        semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
        print('semi ratio: {:.4f}'.format(semi_ratio))

        if semi_ratio == 0.0:
            raise ValueError("Semi ratio == 0!")
        else:
            semi_gt = torch.FloatTensor(semi_gt)
            loss_semi = LAMBDA_SEMI * loss_calc(pred, semi_gt, device, mask=True)
            loss_semi_value.append(loss_semi.item())

        loss_semi += loss_semi_adv
        loss_semi.backward()
        optimizer_G.step()
        loss_semi_adv_value.append(loss_semi_adv.item())

    return pred_remain, ignore_mask_remain, np.average(loss_semi_adv_value), np.average(loss_semi_value)
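
# ---------------------------------------------------------------------------------------
# NOTE: loss_calc is defined elsewhere in the repo. For the mask=True call above it is
# assumed to compute a point-wise cross-entropy over BxNxC predictions while skipping
# the points whose pseudo-label was set to the 999 sentinel. A minimal sketch under that
# assumption (sketch_loss_calc is a hypothetical name):
def sketch_loss_calc(pred, target, device, mask=False):
    """Cross-entropy over BxNxC point predictions; with mask=True, 999 marks ignored points."""
    pred = pred.permute(0, 2, 1)        # BxNxC -> BxCxN, as F.cross_entropy expects
    target = target.long().to(device)
    ignore = 999 if mask else -100      # -100 is F.cross_entropy's default ignore_index
    return F.cross_entropy(pred, target, ignore_index=ignore)
# ---------------------------------------------------------------------------------------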
def train(self, tar_loader, val_loader):
    loss_weight = 0
    args = self.args
    log = self.logger
    device = self.device

    interp_target = nn.Upsample(size=(args.datasets.target.images_size[1],
                                      args.datasets.target.images_size[0]),
                                mode='bilinear', align_corners=True)

    target_iter = enumerate(tar_loader)

    self.model.train()
    self.model = self.model.to(device)

    log.info('########### TRAINING STARTED ############')
    start = time.time()

    for i_iter in range(self.start_iter, self.num_steps):
        if i_iter % int(1 / args.target_frac) == 0:
            self.model.train()
            self.optimizer.zero_grad()
            adjust_learning_rate(self.optimizer, self.preheat, args.num_steps,
                                 args.power, i_iter, args.model.optimizer)
            damping = (1 - i_iter / self.num_steps)  # similar to early stopping

            # Train with Target
            _, batch = next(target_iter)
            images_t, labels_t = batch
            images_t = images_t.to(device)
            pred_target1, pred_target2 = self.model(images_t)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            loss_seg_t = (loss_calc(args.num_classes, pred_target1, labels_t, device) +
                          loss_calc(args.num_classes, pred_target2, labels_t, device))
            loss_seg_t.backward()
            self.losses['seg_t'].append(loss_seg_t.item())

            # Weight Discrepancy Loss
            if args.weight_loss:
                W5 = None
                W6 = None
                # TODO: add ERF-Net
                if args.model.name == 'DeepLab':
                    for (w5, w6) in zip(self.model.layer5.parameters(),
                                        self.model.layer6.parameters()):
                        if W5 is None and W6 is None:
                            W5 = w5.view(-1)
                            W6 = w6.view(-1)
                        else:
                            W5 = torch.cat((W5, w5.view(-1)), 0)
                            W6 = torch.cat((W6, w6.view(-1)), 0)

                # cosine distance between the W5 and W6 vectors
                loss_weight = (torch.matmul(W5, W6) /
                               (torch.norm(W5) * torch.norm(W6)) + 1)  # +1 keeps the loss positive
                loss_weight = loss_weight * args.Lambda_weight * damping * 2
                loss_weight.backward()
                self.losses['weight'].append(loss_weight.item())

            # Optimizer step
            self.optimizer.step()

            if i_iter % 10 == 0:
                log.info('Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f}, loss_weight = {3:.4f}'
                         .format(i_iter, self.num_steps, loss_seg_t, loss_weight))

            if (i_iter % args.save_pred_every == 0 and i_iter != 0) or i_iter == self.num_steps - 1:
                log.info('saving weights...')
                i_iter = i_iter if i_iter != self.num_steps - 1 else i_iter + 1  # for last iter
                torch.save(self.model.state_dict(),
                           join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))

                self.validate(i_iter, val_loader)
                compute_mIoU(i_iter, args.datasets.target.val.label_dir, self.save_path,
                             args.datasets.target.json_file, args.datasets.target.base_list,
                             args.results_dir)
                # save_losses_plot(args.results_dir, self.losses)
                # also save images of source and target
                # save_segmentations(args.images_dir, images_s, labels_s, pred_source1, images_t)

            del images_t, labels_t, pred_target1, pred_target2

    end = time.time()
    days = int((end - start) / 86400)
    log.info('Total training time: {} days, {} hours, {} min, {} sec'.format(
        days,
        int((end - start) / 3600) - (days * 24),
        int((end - start) / 60 % 60),
        int((end - start) % 60)))
    print('### Experiment: ' + args.experiment + ' finished ###')
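
# ---------------------------------------------------------------------------------------
# NOTE: adjust_learning_rate is defined elsewhere in the repo. Given the power argument
# it receives, it is assumed to apply the usual polynomial decay
#     lr = base_lr * (1 - i_iter / num_steps) ** power,
# possibly with a separate schedule during the preheat phase. A sketch under that
# assumption (sketch_poly_lr is a hypothetical, simplified signature):
def sketch_poly_lr(optimizer, base_lr, i_iter, num_steps, power):
    """Polynomial learning-rate decay applied to every parameter group."""
    lr = base_lr * ((1 - float(i_iter) / num_steps) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
# ---------------------------------------------------------------------------------------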
def train(opts):
    if not os.path.exists(opts.snapshot_dir):
        os.makedirs(opts.snapshot_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_G = create_model(type="generator", num_seg_classes=opts.num_seg_classes)
    model_D = create_model(type="discriminator", num_seg_classes=opts.num_seg_classes)

    model_G.to(device)
    model_G.train()
    model_D.to(device)
    model_D.train()

    train_dataset = create_dataset(
        num_inst_classes=NUM_INST_CLASSES,
        num_pts=NUM_PTS,
        mode="train",
        is_noise=IS_NOISE,
        is_rotate=IS_ROTATE,
    )
    train_dataset_size = len(train_dataset)
    print("#Total train: {:6d}".format(train_dataset_size))

    train_gt_dataset = create_GT_dataset(
        num_inst_classes=NUM_INST_CLASSES,
        num_pts=NUM_PTS,
    )

    if opts.partial_data is None:
        trainloader = create_dataloader(
            dataset=train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
        )
        trainloader_gt = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)
        trainloader_gt_iter = iter(trainloader_gt)
    else:
        partial_size = int(opts.partial_data * train_dataset_size)

        if opts.partial_id is not None:
            train_ids = pickle.load(open(opts.partial_id, 'rb'))
            print('loading train ids from {}'.format(opts.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        pickle.dump(train_ids,
                    open(os.path.join(opts.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        # a sampler is mutually exclusive with shuffle, so shuffle stays False here
        trainloader = create_dataloader(
            dataset=train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=False,
            pin_memory=True,
            sampler=train_sampler,
        )
        trainloader_gt = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=False,
            pin_memory=True,
            sampler=train_gt_sampler,
        )
        trainloader_remain = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=False,
            pin_memory=True,
            sampler=train_remain_sampler,
        )
        trainloader_remain_iter = iter(trainloader_remain)
        trainloader_iter = iter(trainloader)
        trainloader_gt_iter = iter(trainloader_gt)

    # optimizer for the segmentation network
    optimizer = optim.Adam(model_G.parameters(), lr=opts.lr_G)
    optimizer.zero_grad()

    # optimizer for the discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=opts.lr_G, betas=(0.9, 0.999))
    optimizer_D.zero_grad()

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    i_iter = 0
    for epoch in np.arange(NUM_EPOCHS):
        loss_ce_value = 0
        loss_adv_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter, LR_G,
                             NUM_EPOCHS * train_dataset_size / BATCH_SIZE, POWER)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter, LR_D,
                               NUM_EPOCHS * train_dataset_size / BATCH_SIZE, POWER)

        if 0 <= epoch <= 9:
            # only train the generator
            for i, mini_batch in enumerate(trainloader):
                # don't accumulate grads in D
                for param in model_D.parameters():
                    param.requires_grad = False

                points, cls_gt, seg_gt = mini_batch
                points, cls_gt, seg_gt = Variable(points).float(), \
                                         Variable(cls_gt).float(), \
                                         Variable(seg_gt).type(torch.LongTensor)
                points, cls_gt, seg_gt = points.to(device), cls_gt.to(device), seg_gt.to(device)

                pred = model_G(points, cls_gt)

                # loss_ce
                loss_ce = loss_calc(pred, seg_gt, device, mask=False)

                # loss_adv
                D_out = model_D(F.softmax(pred, dim=2))
                ignore_mask = np.zeros(seg_gt.shape).astype(bool)
                loss_adv = loss_bce(D_out, make_D_label(gt_label, ignore_mask, device), device)

                loss_seg = loss_ce + LAMBDA_ADV * loss_adv
                loss_seg.backward()

                loss_ce_value += loss_ce.item()
                loss_adv_value += loss_adv.item()

            fastprint('[%d/%d] CE loss: %.3f, ADV loss: %.3f' %
                      (epoch, NUM_EPOCHS, loss_ce_value, loss_adv_value))

        elif 10 <= epoch <= 19:
            # only train the discriminator
            for i, mini_batch in enumerate(trainloader):
                # don't accumulate grads in G
                for param in model_G.parameters():
                    param.requires_grad = False
                for param in model_D.parameters():
                    param.requires_grad = True

                points, cls_gt, seg_gt = mini_batch
                points, cls_gt, seg_gt = Variable(points).float(), \
                                         Variable(cls_gt).float(), \
                                         Variable(seg_gt).type(torch.LongTensor)
                points, cls_gt, seg_gt = points.to(device), cls_gt.to(device), seg_gt.to(device)

                # train D with GT segmentations
                ignore_mask_gt = np.zeros(seg_gt.shape).astype(bool)
                D_gt_v = Variable(one_hot(seg_gt, NUM_SEG_CLASSES)).float().to(device)
                D_out = model_D(D_gt_v)
                loss_D_gt = loss_bce(D_out, make_D_label(gt_label, ignore_mask_gt, device), device)

                # train D with generator predictions
                ignore_mask = np.zeros(seg_gt.shape).astype(bool)
                pred = model_G(points, cls_gt)
                pred = pred.detach()
                D_out = model_D(F.softmax(pred, dim=2))
                loss_D_pred = loss_bce(D_out, make_D_label(pred_label, ignore_mask, device), device)

                loss_D = loss_D_gt + loss_D_pred
                loss_D.backward()
                loss_D_value += loss_D.item()

        else:
            # start using the unlabeled data
            for i, mini_batch in enumerate(trainloader_remain):
                # don't accumulate grads in D
                for param in model_D.parameters():
                    param.requires_grad = False

                # only the points are available (no segmentation GT)
                points, cls, _ = mini_batch
                points, cls = Variable(points).float(), Variable(cls).float()
                points, cls = points.to(device), cls.to(device)

                pred = model_G(points, cls)  # BxNxC
                pred_remain = pred.detach()

                D_out = model_D(F.softmax(pred, dim=2))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy()  # BxN
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)  # Bx2048

                ### semi_adv ###
                loss_semi_adv = LAMBDA_SEMI_ADV * loss_bce(
                    D_out, make_D_label(gt_label, ignore_mask_remain, device), device)

                ### semi ###
                semi_ignore_mask = (D_out_sigmoid < MASK_T)
                semi_gt = pred.data.cpu().numpy().argmax(axis=2)
                semi_gt[semi_ignore_mask] = 999
                semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                if semi_ratio == 0.0:
                    loss_semi_value += 0
                    raise ValueError("Semi ratio == 0!")
                else:
                    semi_gt = torch.FloatTensor(semi_gt)
                    loss_semi = LAMBDA_SEMI * loss_calc(pred, semi_gt, device, mask=True)
                    loss_semi += loss_semi_adv
                    loss_semi.backward()
                    loss_semi_adv_value += loss_semi_adv.item()
                    loss_semi_value += loss_semi.item()