def main():
    """Train the semantic-segmentation network on source data, validating every epoch.

    Per epoch this function:
      1. runs one pass over the ``'train_semseg'`` source loader, optimising
         ``mean(loss_seg + args.entW * loss_ent)``;
      2. steps the learning-rate scheduler and saves a checkpoint;
      3. accumulates a confusion matrix over the ``'val_semseg'`` source loader,
         logs the resulting mIoU and stores a few validation images.

    All configuration comes from ``arg_parser.Parse()``; the function takes no
    arguments and returns ``None``.
    """
    args = arg_parser.Parse()
    # Must be set before the first CUDA call so that only the requested GPUs are visible.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    logger = Logger(args.log_dir)
    logger.PrintAndLogArgs(args)
    saver = ImageAndLossSaver(args.tb_logs_dir, logger.log_folder,
                              args.checkpoints_dir, args.save_pics_every)

    source_train_loader = CreateSrcDataLoader(args, 'train_semseg')
    source_val_loader = CreateSrcDataLoader(args, 'val_semseg')

    semseg_net, semseg_optimizer = CreateModel(args)
    semseg_net = nn.DataParallel(semseg_net.cuda())
    # Multiply the LR by 0.9 at every 10th epoch (0, 10, 20, ...). A plain list of
    # Python ints is what MultiStepLR documents for ``milestones``.
    semseg_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        semseg_optimizer,
        milestones=list(range(0, args.num_epochs, 10)),
        gamma=0.9)

    logger.info('######### Network created #########')
    logger.info('Architecture of Semantic Segmentation network:\n' + str(semseg_net))

    for epoch in range(args.num_epochs):
        # ---------------------------- Training ----------------------------
        semseg_net.train()
        saver.Reset()
        logger.info('#################[Epoch %d]#################' % (epoch + 1))

        for batch_num, (src_img, src_lbl, _, _) in enumerate(source_train_loader):
            start_time = time.time()
            semseg_optimizer.zero_grad()

            src_input_batch = Variable(src_img, requires_grad=False).cuda()
            src_label_batch = Variable(src_lbl, requires_grad=False).cuda()

            # F(S): when labels are supplied the network returns the prediction
            # together with the segmentation and entropy losses.
            predicted, loss_seg, loss_ent = semseg_net(src_input_batch,
                                                       lbl=src_label_batch)
            pred_label = torch.argmax(predicted, dim=1)

            # torch.mean reduces the loss tensors to a scalar (presumably one entry
            # per DataParallel replica -- TODO confirm).
            loss = torch.mean(loss_seg + args.entW * loss_ent)
            saver.WriteSemsegLossHistory(args.model, loss.item())
            loss.backward()
            semseg_optimizer.step()
            saver.running_time += time.time() - start_time

            # NOTE(review): this reads an attribute without calling it; if
            # SaveImagesSemsegIteration is a method rather than a property the
            # condition is always true -- confirm against ImageAndLossSaver.
            if saver.SaveImagesSemsegIteration:
                saver.SaveTrainSemegImages(epoch,
                                           src_img[0, :, :, :],
                                           src_lbl[0, :, :],
                                           pred_label[0, :, :])

            if (batch_num + 1) % args.print_every == 0:
                logger.info('Finished Batch %d' % (batch_num + 1))

        # Update LR:
        semseg_scheduler.step()
        # Save checkpoint:
        saver.SaveModelsCheckpointSemseg(semseg_net, args.model, epoch)

        # --------------------------- Validation ---------------------------
        semseg_net.eval()
        # Pick 5 random sample indices and map them to the batches that contain them;
        # one image is saved from each batch selected this way.
        rand_samp_inds = np.random.randint(0, len(source_val_loader.dataset), 5)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
        rand_batchs = np.floor(rand_samp_inds / args.batch_size).astype(int)

        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES)).cuda()
        for val_batch_num, (src_img, src_lbl, _, _) in enumerate(source_val_loader):
            with torch.no_grad():
                src_input_batch = Variable(src_img, requires_grad=False).cuda()
                src_label_batch = Variable(src_lbl, requires_grad=False).cuda()
                pred_softs_batch = semseg_net(src_input_batch)
                pred_batch = torch.argmax(pred_softs_batch, dim=1)
                cm += compute_cm_batch_torch(pred_batch, src_label_batch,
                                             IGNORE_LABEL, NUM_CLASSES)

                # NOTE(review): rand_batchs holds 0-based batch indices while the
                # test below is 1-based, so batch 0 is never picked -- kept as in
                # the original; confirm whether that is intended.
                if (val_batch_num + 1) in rand_batchs:
                    # Draw the offset from the actual batch size: the last batch
                    # may hold fewer than args.batch_size samples.
                    rand_offset = np.random.randint(0, pred_batch.shape[0])
                    saver.SaveValidationImages(
                        epoch,
                        src_input_batch[rand_offset, :, :, :],
                        src_label_batch[rand_offset, :, :],
                        pred_batch[rand_offset, :, :])

        iou, miou = compute_iou_torch(cm)
        saver.SaveEpochAccuracy(iou, miou, epoch)
        # NOTE(review): the message says "target domain" although validation here
        # runs on the source validation loader; text left unchanged.
        logger.info(
            'Average accuracy of Epoch #%d on target domain: mIoU = %2f' %
            (epoch + 1, miou))
        logger.info(
            '-----------------------------------Epoch #%d Finished-----------------------------------'
            % (epoch + 1))
        # Drop references to the per-epoch tensors before the next epoch starts.
        del cm, pred_softs_batch, pred_batch

    saver.tb.close()
    logger.info('Finished training.')
def main():
    """Adversarially train generator, discriminator and segmentation network.

    Per epoch this function runs ``steps_per_epoch`` iterations. Each iteration
    draws one source batch and one target-train batch, builds ``G(S, T)`` with the
    generator, and then trains either

      * the discriminator (on ``D(G(S, T))`` and ``D(T)`` in alternating
        discriminator steps), or
      * the generator together with the segmentation network,

    according to the schedule defined by ``args.generator_boost``,
    ``args.discriminator_iters`` and ``args.generator_iters``. After the training
    iterations the segmentation network is evaluated on the target validation
    loader: a confusion matrix is accumulated, mIoU is logged and a few
    validation images are saved.

    All configuration comes from ``arg_parser.Parse()``; the function takes no
    arguments and returns ``None``.
    """
    args = arg_parser.Parse()
    # Must be set before the first CUDA call so that only the requested GPUs are visible.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    logger = Logger(args.log_dir)
    logger.PrintAndLogArgs(args)
    saver = ImageAndLossSaver(args.tb_logs_dir, logger.log_folder,
                              args.checkpoints_dir, args.save_pics_every)

    source_loader = CreateSrcDataLoader(args)
    target_train_loader = CreateTrgDataLoader(args, 'train')
    target_eval_loader = CreateTrgDataLoader(args, 'val')

    # Both training datasets are told to use the size of the larger one as their
    # epoch size, and the step count is derived from that same size.
    epoch_size = np.maximum(len(target_train_loader.dataset),
                            len(source_loader.dataset))
    steps_per_epoch = int(np.floor(epoch_size / args.batch_size))
    source_loader.dataset.SetEpochSize(epoch_size)
    target_train_loader.dataset.SetEpochSize(epoch_size)

    generator = model.DeepLPFNet()
    generator = nn.DataParallel(generator.cuda())
    generator_criterion = model.GeneratorLoss()
    generator_optimizer = optim.Adam(generator.parameters(),
                                     lr=args.generator_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)

    discriminator = model.Discriminator()
    discriminator = nn.DataParallel(discriminator.cuda())
    discriminator_criterion = model.DiscriminatorLoss()
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=args.discriminator_lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-08)

    semseg_net, semseg_optimizer = CreateModel(args)
    semseg_net = nn.DataParallel(semseg_net.cuda())

    logger.info('######### Network created #########')
    logger.info('Architecture of Generator:\n' + str(generator))
    logger.info('Architecture of Discriminator:\n' + str(discriminator))
    logger.info('Architecture of Backbone net:\n' + str(semseg_net))

    # One discriminator/generator cycle, in iterations.
    cycle_len = args.discriminator_iters + args.generator_iters

    for epoch in range(args.num_epochs):
        # ---------------------------- Training ----------------------------
        generator.train()
        discriminator.train()
        semseg_net.train()
        saver.Reset()
        # Toggled after every discriminator step so that D sees G(S,T) and T in turn.
        discriminate_src = True
        # Only the training loaders are iterated manually; the validation loader is
        # consumed by the ``for`` loop below, so no iterator is created for it here.
        source_loader_iter = iter(source_loader)
        target_train_loader_iter = iter(target_train_loader)
        logger.info('#################[Epoch %d]#################' % (epoch + 1))

        for batch_num in range(steps_per_epoch):
            start_time = time.time()

            # The first ``generator_boost`` iterations always train the generator;
            # afterwards each cycle starts with ``discriminator_iters`` discriminator
            # steps followed by ``generator_iters`` generator steps.
            training_discriminator = (
                batch_num >= args.generator_boost
                and (batch_num - args.generator_boost) % cycle_len
                < args.discriminator_iters)

            # Built-in next() instead of the Python 2 style ``iterator.next()``,
            # which DataLoader iterators no longer provide.
            src_img, src_lbl, _, _ = next(source_loader_iter)  # new batch source
            trg_img, _, _, _ = next(target_train_loader_iter)  # new batch target

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()
            semseg_optimizer.zero_grad()

            src_input_batch = Variable(src_img, requires_grad=False).cuda()
            src_label_batch = Variable(src_lbl, requires_grad=False).cuda()
            trg_input_batch = Variable(trg_img, requires_grad=False).cuda()

            src_in_trg = generator(src_input_batch, trg_input_batch)  # G(S,T)

            if training_discriminator:
                # Train discriminator: exactly one of the two inputs is evaluated
                # per step, the other is passed to the criterion as None.
                if discriminate_src:
                    discriminator_src_in_trg = discriminator(src_in_trg)  # D(G(S,T))
                    discriminator_trg = None  # D(T)
                else:
                    discriminator_src_in_trg = None  # D(G(S,T))
                    discriminator_trg = discriminator(trg_input_batch)  # D(T)
                discriminate_src = not discriminate_src
                loss = discriminator_criterion(discriminator_src_in_trg,
                                               discriminator_trg)
            else:
                # Train generator and semseg net.
                # NOTE(review): the generator loss receives D(T), not D(G(S,T));
                # kept exactly as in the original -- confirm this is intended.
                discriminator_trg = discriminator(trg_input_batch)  # D(T)
                predicted, loss_seg, loss_ent = semseg_net(
                    src_in_trg, lbl=src_label_batch)  # F(G(S,T))
                src_in_trg_labels = torch.argmax(predicted, dim=1)
                loss = generator_criterion(loss_seg, loss_ent, args.entW,
                                           discriminator_trg)

            saver.WriteLossHistory(training_discriminator, loss.item())
            loss.backward()

            if training_discriminator:
                discriminator_optimizer.step()
            else:
                generator_optimizer.step()
                semseg_optimizer.step()

            saver.running_time += time.time() - start_time

            # ``src_in_trg_labels`` only exists on generator steps, hence the guard.
            # NOTE(review): SaveImagesIteration is read without being called; if it
            # is a method rather than a property it is always truthy -- confirm.
            if (not training_discriminator) and saver.SaveImagesIteration:
                saver.SaveTrainImages(epoch,
                                      src_img[0, :, :, :],
                                      src_in_trg[0, :, :, :],
                                      src_lbl[0, :, :],
                                      src_in_trg_labels[0, :, :])

            if (batch_num + 1) % args.print_every == 0:
                logger.PrintAndLogData(saver, epoch, batch_num, args.print_every)

            if (batch_num + 1) % args.save_checkpoint == 0:
                saver.SaveModelsCheckpoint(semseg_net, discriminator, generator,
                                           epoch, batch_num)

        # --------------------------- Validation ---------------------------
        semseg_net.eval()
        # Pick 5 random sample indices and map them to the batches that contain them;
        # one image is saved from each batch selected this way.
        rand_samp_inds = np.random.randint(0, len(target_eval_loader.dataset), 5)
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
        rand_batchs = np.floor(rand_samp_inds / args.batch_size).astype(int)

        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES)).cuda()
        for val_batch_num, (trg_eval_img, trg_eval_lbl, _, _) in enumerate(target_eval_loader):
            with torch.no_grad():
                trg_input_batch = Variable(trg_eval_img, requires_grad=False).cuda()
                trg_label_batch = Variable(trg_eval_lbl, requires_grad=False).cuda()
                pred_softs_batch = semseg_net(trg_input_batch)
                pred_batch = torch.argmax(pred_softs_batch, dim=1)
                cm += compute_cm_batch_torch(pred_batch, trg_label_batch,
                                             IGNORE_LABEL, NUM_CLASSES)
                print('Validation: saw', val_batch_num * args.batch_size, 'examples')

                # NOTE(review): rand_batchs holds 0-based batch indices while the
                # test below is 1-based, so batch 0 is never picked -- kept as in
                # the original; confirm whether that is intended.
                if (val_batch_num + 1) in rand_batchs:
                    # Draw the offset from the actual batch size: the last batch
                    # may hold fewer than args.batch_size samples.
                    rand_offset = np.random.randint(0, pred_batch.shape[0])
                    saver.SaveValidationImages(
                        epoch,
                        trg_input_batch[rand_offset, :, :, :],
                        trg_label_batch[rand_offset, :, :],
                        pred_batch[rand_offset, :, :])

        iou, miou = compute_iou_torch(cm)
        saver.SaveEpochAccuracy(iou, miou, epoch)
        logger.info(
            'Average accuracy of Epoch #%d on target domain: mIoU = %2f' %
            (epoch + 1, miou))
        logger.info(
            '-----------------------------------Epoch #%d Finished-----------------------------------'
            % (epoch + 1))
        # Drop references to the per-epoch tensors before the next epoch starts.
        del cm, trg_input_batch, trg_label_batch, pred_softs_batch, pred_batch

    saver.tb.close()
    logger.info('Finished training.')