def test_D2L(epoch=50):
    """Run a trained D2LModel over the test set and save its predictions.

    Builds the options for the test split (``<root>/dataset/test``), passes
    ``epoch`` to ``model.load_networks`` (checkpoints dir
    ``<root>/checkpoints/D2L``), runs the model on every batch, and writes
    each prediction as ``<img_id>_pred_sem_wo_mask.png`` into
    ``opt.result_dir`` (the test dataset directory).

    Args:
        epoch: value passed to ``model.load_networks``; defaults to 50, the
            value previously hard-coded here.
    """
    import torch

    # set options
    opt = Option()
    opt.root_dir = root + '/dataset/test'
    opt.checkpoints_dir = root + '/checkpoints/D2L'
    opt.result_dir = opt.root_dir
    opt.gpu_ids = [0]
    opt.batch_size = 16
    opt.coarse = False
    opt.pool_size = 0
    opt.no_lsgan = True
    opt.is_train = False
    opt.fine_tune_sidewalk = False

    # load data
    dataset_test = D2LDataLoader(root_dir=opt.root_dir, train=opt.is_train,
                                 coarse=opt.coarse,
                                 fine_tune_sidewalk=opt.fine_tune_sidewalk)
    data_loader_test = DataLoader(dataset_test, batch_size=opt.batch_size,
                                  shuffle=opt.shuffle,
                                  num_workers=opt.num_workers,
                                  pin_memory=opt.pin_memory)

    # load model
    model = D2LModel()
    model.initialize(opt)
    model.load_networks(epoch)

    # do testing; no gradients are needed, so skip building the autograd
    # graph to save memory and time
    with torch.no_grad():
        for idx_batch, data_batch in enumerate(data_loader_test):
            print(idx_batch)
            model.set_input(data_batch, 0)
            model.forward()
            fake_S = model.fake_S.detach().cpu()
            for i in range(fake_S.size(0)):
                # assumes the generator output lies in [-1, 1] (TODO confirm);
                # rescale to [0, 1] before saving as an image
                sem = fake_S[i, :, :, :] * 0.5 + 0.5
                img_id = data_batch['img_id'][i]
                # save image
                path_sem = opt.result_dir + '/' + img_id + '_pred_sem_wo_mask.png'
                torchvision.utils.save_image(sem.float(), path_sem)
def train_D2L():
    """Train a D2LModel on the ``<root>/dataset/D2L/train`` split.

    Calls ``model.load_networks(-1)`` before training (presumably resuming
    from the latest checkpoint, if any — confirm against D2LModel). Then it
    runs epochs ``opt.epoch_count`` through ``opt.niter + opt.niter_decay``
    inclusive. After every batch the generator loss is printed and appended
    to ``<root>/dataset/D2L/logs.txt``. Networks are saved every 5 epochs and
    after the final epoch.
    """
    # set options
    opt = Option()
    opt.root_dir = root + '/dataset/D2L'
    opt.checkpoints_dir = root + '/checkpoints/D2L'
    opt.gpu_ids = [0]
    opt.batch_size = 16
    opt.coarse = False
    opt.pool_size = 0
    opt.no_lsgan = True
    opt.fine_tune_sidewalk = False

    # load data
    root_dir_train = opt.root_dir + '/train'
    dataset_train = D2LDataLoader(root_dir=root_dir_train, train=True,
                                  coarse=opt.coarse,
                                  fine_tune_sidewalk=opt.fine_tune_sidewalk)
    data_loader_train = DataLoader(dataset_train, batch_size=opt.batch_size,
                                   shuffle=opt.shuffle,
                                   num_workers=opt.num_workers,
                                   pin_memory=opt.pin_memory)

    # load model
    model = D2LModel()
    model.initialize(opt)
    model.load_networks(-1)

    # do training
    last_epoch = opt.niter + opt.niter_decay
    for epoch in range(opt.epoch_count, last_epoch + 1):
        # 'with' guarantees the log is flushed and closed even if a batch raises
        with open(opt.root_dir + '/logs.txt', 'a') as log_file:
            for idx_batch, data_batch in enumerate(data_loader_train):
                print(idx_batch)
                model.set_input(data_batch, epoch)
                model.optimize_parameters()
                msg = ('epoch: ' + str(epoch) + ', train loss_G_Loss: '
                       + str(model.loss_G.data))
                print(msg)
                log_file.write(msg + '\n')

        # save every 5 epochs, and always after the last epoch so the final
        # weights are not lost when last_epoch is not a multiple of 5
        if epoch % 5 == 0 or epoch == last_epoch:
            model.save_networks(epoch)