def main(params):
    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, default=None, required=True,
                        help='The path to the pretrained weights of the model')
    parser.add_argument('--crop_height', type=int, default=640,
                        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--data', type=str, default='/path/to/data',
                        help='Path of training data')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using.')
    parser.add_argument('--cuda', type=str, default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=bool, default=True,
                        help='Whether to use gpu for training')
    parser.add_argument('--num_classes', type=int, default=32,
                        help='num of object classes (with void)')
    args = parser.parse_args(params)

    # create dataset and dataloader
    test_path = os.path.join(args.data, 'test')
    # test_path = os.path.join(args.data, 'train')
    test_label_path = os.path.join(args.data, 'test_labels')
    # test_label_path = os.path.join(args.data, 'train_labels')
    csv_path = os.path.join(args.data, 'class_dict.csv')
    dataset = CamVid(test_path, test_label_path, csv_path,
                     scale=(args.crop_height, args.crop_width), mode='test')
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
    )

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # load pretrained model if exists
    print('load model from %s ...' % args.checkpoint_path)
    model.module.load_state_dict(torch.load(args.checkpoint_path))
    print('Done!')

    # get label info
    label_info = get_label_info(csv_path)
    # test
    eval(model, dataloader, args, label_info)
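# Usage sketch (not part of the original script, added for illustration): since the entry
# point parses an argv-style list via parse_args(params), it could be launched like this.
# All paths below are placeholders.
if __name__ == '__main__':
    params = [
        '--checkpoint_path', '/path/to/checkpoint.pth',
        '--data', '/path/to/CamVid',
        '--cuda', '0',
        '--context_path', 'resnet101',
        '--num_classes', '32',
    ]
    main(params)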
def main(params):
    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=5,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=1,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid",
                        help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=640,
                        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using.')
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='learning rate used for training')
    parser.add_argument('--data', type=str, default='/path/to/data',
                        help='path of training data')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=bool, default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None,
                        help='path to save model')
    args = parser.parse_args(params)

    # create dataset and dataloader
    train_path = os.path.join(args.data, 'train')
    train_label_path = os.path.join(args.data, 'train_labels')
    val_path = os.path.join(args.data, 'val')
    val_label_path = os.path.join(args.data, 'val_labels')
    csv_path = os.path.join(args.data, 'class_dict.csv')
    dataset_train = CamVid(train_path, train_label_path, csv_path,
                           scale=(args.crop_height, args.crop_width), mode='train')
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    dataset_val = CamVid(val_path, val_label_path, csv_path,
                         scale=(args.crop_height, args.crop_width), mode='val')
    dataloader_val = DataLoader(
        dataset_val,
        # this has to be 1
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers)

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        model.module.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')

    # train
    train(args, model, optimizer, dataloader_train, dataloader_val, csv_path)
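# Sanity-check sketch (not in the original file): verifies that the CamVid folder layout
# implied by the os.path.join calls above is present before training starts. The helper
# name check_camvid_layout is hypothetical.
import os

def check_camvid_layout(data_root):
    expected = ['train', 'train_labels', 'val', 'val_labels',
                'test', 'test_labels', 'class_dict.csv']
    missing = [name for name in expected
               if not os.path.exists(os.path.join(data_root, name))]
    if missing:
        raise FileNotFoundError('Missing CamVid entries: ' + ', '.join(missing))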
def main():
    # Call Python's garbage collector and empty torch's CUDA cache, just in case
    gc.collect()
    torch.cuda.empty_cache()

    # Enable cuDNN in benchmark mode. For more info see:
    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Load BiSeNet generator
    generator = BiSeNet(NUM_CLASSES, CONTEXT_PATH).cuda()
    # generator.load_state_dict(torch.load('./checkpoint_101_adversarial_both_augmentation_epoch_len_IDDA/37_Generator.pth'))
    generator.train()

    # Build discriminator
    discriminator = Discriminator(NUM_CLASSES).cuda()
    # discriminator.load_state_dict(torch.load('./checkpoint_101_adversarial_both_augmentation_epoch_len_IDDA/37_Discriminator.pth'))
    discriminator.train()

    # Load source dataset
    source_dataset = IDDA(image_path=IDDA_PATH,
                          label_path=IDDA_LABEL_PATH,
                          classes_info_path=JSON_IDDA_PATH,
                          scale=(CROP_HEIGHT, CROP_WIDTH),
                          loss=LOSS,
                          mode='train')
    source_dataloader = DataLoader(source_dataset,
                                   batch_size=BATCH_SIZE_IDDA,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    # Load target dataset
    target_dataset = CamVid(image_path=CAMVID_PATH,
                            label_path=CAMVID_LABEL_PATH,
                            csv_path=CSV_CAMVID_PATH,
                            scale=(CROP_HEIGHT, CROP_WIDTH),
                            loss=LOSS,
                            mode='adversarial_train')
    target_dataloader = DataLoader(target_dataset,
                                   batch_size=BATCH_SIZE_CAMVID,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    optimizer_BiSeNet = torch.optim.SGD(generator.parameters(),
                                        lr=LEARNING_RATE_SEGMENTATION,
                                        momentum=MOMENTUM,
                                        weight_decay=WEIGHT_DECAY)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=LEARNING_RATE_DISCRIMINATOR,
                                               betas=(0.9, 0.99))

    # Loss for discriminator training: sigmoid layer + BCELoss
    bce_loss = nn.BCEWithLogitsLoss()
    # Loss for segmentation: log-softmax layer + 2D cross entropy
    cross_entropy_loss = CrossEntropy2d()

    # for epoch in range(NUM_STEPS):
    for epoch in range(1, 51):
        source_dataloader_iter = iter(source_dataloader)
        target_dataloader_iter = iter(target_dataloader)
        print(f'begin epoch {epoch}')

        # Reset gradients of generator and discriminator
        optimizer_BiSeNet.zero_grad()
        optimizer_discriminator.zero_grad()

        # Reset the accumulated losses
        l_seg_to_print_acc, l_adv_to_print_acc, l_d_to_print_acc = 0, 0, 0

        # Compute the learning rate for this epoch
        adjust_learning_rate(optimizer_BiSeNet, LEARNING_RATE_SEGMENTATION, epoch, NUM_STEPS, POWER)
        adjust_learning_rate(optimizer_discriminator, LEARNING_RATE_DISCRIMINATOR, epoch, NUM_STEPS, POWER)

        for i in tqdm(range(len(target_dataloader))):
            optimizer_BiSeNet.zero_grad()
            optimizer_discriminator.zero_grad()
            l_seg_to_print, l_adv_to_print, l_d_to_print = minibatch(
                source_dataloader_iter, target_dataloader_iter, generator, discriminator,
                cross_entropy_loss, bce_loss, source_dataloader, target_dataloader)
            l_seg_to_print_acc += l_seg_to_print
            l_adv_to_print_acc += l_adv_to_print
            l_d_to_print_acc += l_d_to_print
            # Run the optimizers using the gradients obtained via backpropagation
            optimizer_BiSeNet.step()
            optimizer_discriminator.step()

        # Print the accumulated losses at the end of each epoch
        print(
            f'epoch = {epoch}/{NUM_STEPS}, loss_seg = {l_seg_to_print_acc:.3f}, '
            f'loss_adv = {l_adv_to_print_acc:.3f}, loss_D = {l_d_to_print_acc:.3f}'
        )

        # Save intermediate checkpoints
        if epoch % CHECKPOINT_STEP == 0 and epoch != 0:
            # If the directory does not exist, create it
            if not os.path.isdir(CHECKPOINT_PATH):
                os.mkdir(CHECKPOINT_PATH)
            # Save the parameters of the generator (segmentation network) and discriminator
            generator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Generator.pth")
            torch.save(generator.state_dict(), generator_checkpoint_path)
            discriminator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Discriminator.pth")
            torch.save(discriminator.state_dict(), discriminator_checkpoint_path)
            print(
                f"saved:\n{generator_checkpoint_path}\n{discriminator_checkpoint_path}"
            )
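# Sketch of adjust_learning_rate (an assumption: the POWER argument suggests the common
# polynomial decay schedule, but the project's actual implementation may differ).
def adjust_learning_rate(optimizer, base_lr, epoch, max_epochs, power):
    # lr = base_lr * (1 - epoch / max_epochs) ** power, applied to every param group
    lr = base_lr * ((1 - epoch / max_epochs) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr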
def main(params):
    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=10,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=2,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid",
                        help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=720,
                        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=960,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using: resnet18 or resnet101.')
    parser.add_argument('--learning_rate_G', type=float, default=0.01,
                        help='learning rate for G')
    parser.add_argument('--learning_rate_D', type=float, default=0.01,
                        help='learning rate for D (e.g. 1e-4)')
    parser.add_argument('--data_CamVid', type=str, default='',
                        help='path of the CamVid training data')
    parser.add_argument('--data_IDDA', type=str, default='',
                        help='path of the IDDA training data')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=bool, default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None,
                        help='path to save model')
    parser.add_argument('--optimizer_G', type=str, default='rmsprop',
                        help='optimizer for G: rmsprop, sgd or adam')
    parser.add_argument('--optimizer_D', type=str, default='rmsprop',
                        help='optimizer for D: rmsprop, sgd or adam')
    parser.add_argument('--loss', type=str, default='dice',
                        help='loss function: dice or crossentropy')
    parser.add_argument('--loss_G', type=str, default='dice',
                        help='loss function for G: dice or crossentropy')
    parser.add_argument('--lambda_adv', type=float, default=0.01,
                        help='lambda coefficient for the adversarial loss')
    args = parser.parse_args(params)

    # create dataset and dataloader for CamVid
    CamVid_train_path = [os.path.join(args.data_CamVid, 'train'),
                         os.path.join(args.data_CamVid, 'val')]
    CamVid_train_label_path = [os.path.join(args.data_CamVid, 'train_labels'),
                               os.path.join(args.data_CamVid, 'val_labels')]
    CamVid_test_path = os.path.join(args.data_CamVid, 'test')
    CamVid_test_label_path = os.path.join(args.data_CamVid, 'test_labels')
    CamVid_csv_path = os.path.join(args.data_CamVid, 'class_dict.csv')
    CamVid_dataset_train = CamVid(CamVid_train_path, CamVid_train_label_path, CamVid_csv_path,
                                  scale=(args.crop_height, args.crop_width),
                                  loss=args.loss, mode='train')
    CamVid_dataloader_train = DataLoader(
        CamVid_dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    CamVid_dataset_val = CamVid(CamVid_test_path, CamVid_test_label_path, CamVid_csv_path,
                                scale=(args.crop_height, args.crop_width),
                                loss=args.loss, mode='test')
    CamVid_dataloader_val = DataLoader(
        CamVid_dataset_val,
        # this has to be 1
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers
    )

    # create dataset and dataloader for IDDA
    IDDA_path = os.path.join(args.data_IDDA, 'rgb')
    IDDA_label_path = os.path.join(args.data_IDDA, 'labels')
    IDDA_info_path = os.path.join(args.data_IDDA, 'classes_info.json')
    IDDA_dataset = IDDA(IDDA_path, IDDA_label_path, IDDA_info_path, CamVid_csv_path,
                        scale=(args.crop_height, args.crop_width), loss=args.loss)
    IDDA_dataloader = DataLoader(
        IDDA_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )

    # build model_G
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model_G = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model_G = torch.nn.DataParallel(model_G).cuda()

    # build model_D
    model_D = DW_Discriminator(args.num_classes)
    if torch.cuda.is_available() and args.use_gpu:
        model_D = torch.nn.DataParallel(model_D).cuda()

    # build optimizer G
    if args.optimizer_G == 'rmsprop':
        optimizer_G = torch.optim.RMSprop(model_G.parameters(), args.learning_rate_G)
    elif args.optimizer_G == 'sgd':
        optimizer_G = torch.optim.SGD(model_G.parameters(), args.learning_rate_G,
                                      momentum=0.9, weight_decay=1e-4)
    elif args.optimizer_G == 'adam':
        optimizer_G = torch.optim.Adam(model_G.parameters(), args.learning_rate_G)
    else:
        print('not supported optimizer\n')
        return None

    # build optimizer D
    if args.optimizer_D == 'rmsprop':
        optimizer_D = torch.optim.RMSprop(model_D.parameters(), args.learning_rate_D)
    elif args.optimizer_D == 'sgd':
        optimizer_D = torch.optim.SGD(model_D.parameters(), args.learning_rate_D,
                                      momentum=0.9, weight_decay=1e-4)
    elif args.optimizer_D == 'adam':
        optimizer_D = torch.optim.Adam(model_D.parameters(), args.learning_rate_D)
    else:
        print('not supported optimizer\n')
        return None

    curr_epoch = 0
    max_miou = 0

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        state = torch.load(os.path.realpath(args.pretrained_model_path))
        # load the pretrained model_G and its optimizer
        model_G.module.load_state_dict(state['model_G_state'])
        optimizer_G.load_state_dict(state['optimizer_G'])
        # load the pretrained model_D and its optimizer
        model_D.module.load_state_dict(state['model_D_state'])
        optimizer_D.load_state_dict(state['optimizer_D'])
        curr_epoch = state["epoch"]
        max_miou = state["max_miou"]
        print(str(curr_epoch - 1) + " already trained")
        print("start training from epoch " + str(curr_epoch))
        print('Done!')

    # train
    train(args, model_G, model_D, optimizer_G, optimizer_D,
          CamVid_dataloader_train, CamVid_dataloader_val, IDDA_dataloader,
          curr_epoch, max_miou)
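# Checkpoint format sketch (inferred from the keys read in the resume branch above; the
# helper name save_checkpoint is hypothetical). train() is expected to save a dict with
# these entries so that training can later be resumed.
def save_checkpoint(path, model_G, model_D, optimizer_G, optimizer_D, epoch, max_miou):
    torch.save({
        'model_G_state': model_G.module.state_dict(),
        'model_D_state': model_D.module.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
        'epoch': epoch,
        'max_miou': max_miou,
    }, path)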