def execute(gpu, exp_batch, exp_alias):
    from time import gmtime, strftime

    manualSeed = g_conf.SEED
    torch.cuda.manual_seed(manualSeed)

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    set_type_of_process('train')

    coil_logger.add_message('Loading', {'GPU': gpu})

    if not os.path.exists('_output_logs'):
        os.mkdir('_output_logs')
    sys.stdout = open(os.path.join(
        '_output_logs',
        g_conf.PROCESS_NAME + '_' + str(os.getpid()) + ".out"),
        "a", buffering=1)

    if monitorer.get_status(exp_batch, exp_alias + '.yaml',
                            g_conf.PROCESS_NAME)[0] == "Finished":
        return

    full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                g_conf.TRAIN_DATASET_NAME)
    real_dataset = g_conf.TARGET_DOMAIN_PATH

    # main data loader (paired simulated source / real target domains)
    dataset = CoILDataset(full_dataset, real_dataset,
                          transform=transforms.Compose([transforms.ToTensor()]))

    sampler = BatchSequenceSampler(
        splitter.control_steer_split(dataset.measurements, dataset.meta_data),
        g_conf.BATCH_SIZE, g_conf.NUMBER_IMAGES_SEQUENCE,
        g_conf.SEQUENCE_STRIDE)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=sampler,
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=True)

    st = lambda aug: iag.Sometimes(aug, 0.4)
    oc = lambda aug: iag.Sometimes(aug, 0.3)
    rl = lambda aug: iag.Sometimes(aug, 0.09)
    augmenter = iag.Augmenter([iag.ToGPU()] + [
        rl(iag.GaussianBlur((0, 1.5))),  # blur images with a sigma between 0 and 1.5
        rl(iag.AdditiveGaussianNoise(
            loc=0, scale=(0.0, 0.05), per_channel=0.5)),  # add gaussian noise to images
        oc(iag.Dropout((0.0, 0.10), per_channel=0.5)),  # randomly remove up to 10% of the pixels
        oc(iag.CoarseDropout(
            (0.0, 0.10), size_percent=(0.08, 0.2),
            per_channel=0.5)),  # randomly remove up to 10% of the pixels in patches
        oc(iag.Add((-40, 40), per_channel=0.5)),  # shift brightness by -40 to +40 of the original value
        st(iag.Multiply((0.10, 2), per_channel=0.2)),  # scale brightness to 10-200% of the original value
        rl(iag.ContrastNormalization((0.5, 1.5), per_channel=0.5)),  # improve or worsen the contrast
        rl(iag.Grayscale((0.0, 1))),  # put grayscale
    ])  # do all of the above in random order

    l1weight = g_conf.L1_WEIGHT
    task_adv_weight = g_conf.TASK_ADV_WEIGHT
    image_size = tuple([88, 200])

    print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))
    print("GPU", gpu)
    print("Configurations of ", exp_alias)
    print("GANMODEL_NAME", g_conf.GANMODEL_NAME)
    print("LOSS_FUNCTION", g_conf.LOSS_FUNCTION)
    print("LR_G, LR_D, LR", g_conf.LR_G, g_conf.LR_D, g_conf.LEARNING_RATE)
    print("SKIP", g_conf.SKIP)
    print("TYPE", g_conf.TYPE)
    print("L1 WEIGHT", g_conf.L1_WEIGHT)
    print("TASK ADV WEIGHT", g_conf.TASK_ADV_WEIGHT)
    print("LAB SMOOTH", g_conf.LABSMOOTH)

    if g_conf.GANMODEL_NAME == 'LSDcontrol':
        netD = ganmodels._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels._netG(loss=g_conf.LOSS_FUNCTION,
                               skip=g_conf.SKIP).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch':
        netD = ganmodels_nopatch._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch_smaller':
        netD = ganmodels_nopatch_smaller._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch_smaller._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task':
        netD_task = ganmodels_task._netD_task(loss=g_conf.LOSS_FUNCTION).cuda()
        netD_img = ganmodels_task._netD_img(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_task._netG(loss=g_conf.LOSS_FUNCTION).cuda()
        netF = ganmodels_task._netF(loss=g_conf.LOSS_FUNCTION).cuda()

        if g_conf.PRETRAINED == 'RECON':
            netF_statedict = torch.load('netF_GAN_Pretrained.wts')
            netF.load_state_dict(netF_statedict)
        elif g_conf.PRETRAINED == 'IL':
            print("Loading IL")
            model_IL = torch.load('best_loss_20-06_EpicClearWeather.pth')
            model_IL_state_dict = model_IL['state_dict']

            netF_state_dict = netF.state_dict()
            print(len(netF_state_dict.keys()), len(model_IL_state_dict.keys()))
            # Transfer weights key-by-key; both state dicts are assumed to
            # list parameters in the same order.
            for i, keys in enumerate(
                    zip(netF_state_dict.keys(), model_IL_state_dict.keys())):
                newkey, oldkey = keys
                # if newkey.split('.')[0] == "branch" and oldkey.split('.')[0] == "branches":
                #     print("No Transfer of ", newkey, " to ", oldkey)
                # else:
                print("Transferring ", newkey, " to ", oldkey)
                netF_state_dict[newkey] = model_IL_state_dict[oldkey]
            netF.load_state_dict(netF_state_dict)
            print("IL Model Loaded!")

    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task_2d':
        netD_bin = ganmodels_task._netD_task().cuda()
        netD_img = ganmodels_task._netD_img().cuda()
        netG = ganmodels_task._netG().cuda()
        netF = ganmodels_task._netF().cuda()

        if g_conf.PRETRAINED == 'IL':
            print("Loading IL")
            model_IL = torch.load(g_conf.IL_AGENT_PATH)
            model_IL_state_dict = model_IL['state_dict']

            netF_state_dict = netF.state_dict()
            print(len(netF_state_dict.keys()), len(model_IL_state_dict.keys()))
            for i, keys in enumerate(
                    zip(netF_state_dict.keys(), model_IL_state_dict.keys())):
                newkey, oldkey = keys
                print("Transferring ", newkey, " to ", oldkey)
                netF_state_dict[newkey] = model_IL_state_dict[oldkey]
            netF.load_state_dict(netF_state_dict)
            print("IL Model Loaded!")

        #####
        if g_conf.IF_AUG:
            print("Loading Aug Decoder")
            model_dec = torch.load(g_conf.DECODER_RECON_PATH)
        else:
            print("Loading Decoder")
            model_dec = torch.load(g_conf.DECODER_RECON_PATH)
        model_dec_state_dict = model_dec['stateG_dict']

        netG_state_dict = netG.state_dict()
        print(len(netG_state_dict.keys()), len(model_dec_state_dict.keys()))
        for i, keys in enumerate(
                zip(netG_state_dict.keys(), model_dec_state_dict.keys())):
            newkey, oldkey = keys
            print("Transferring ", newkey, " to ", oldkey)
            netG_state_dict[newkey] = model_dec_state_dict[oldkey]
        netG.load_state_dict(netG_state_dict)
        print("Decoder Model Loaded!")

    init_weights(netD_bin)
    init_weights(netD_img)
    # init_weights(netG)  # netG keeps the pretrained decoder weights

    print(netD_bin)
    print(netF)

    optimD_bin = torch.optim.Adam(netD_bin.parameters(), lr=g_conf.LR_D,
                                  betas=(0.5, 0.999))
    optimD_img = torch.optim.Adam(netD_img.parameters(), lr=g_conf.LR_D,
                                  betas=(0.5, 0.999))
    optimG = torch.optim.Adam(netG.parameters(), lr=g_conf.LR_D,  # note: reuses the discriminator learning rate
                              betas=(0.5, 0.999))
    if g_conf.TYPE == 'task':
        optimF = torch.optim.Adam(netF.parameters(), lr=g_conf.LEARNING_RATE)
        Task_Loss = TaskLoss()

    if g_conf.GANMODEL_NAME == 'LSDcontrol_task_2d':
        print("Using cross entropy!")
        Loss = torch.nn.CrossEntropyLoss().cuda()
    L1_loss = torch.nn.L1Loss().cuda()

    iteration = 0
    best_loss_iter_F = 0
    best_loss_iter_G = 0
    best_lossF = 1000000.0
    best_lossD = 1000000.0
    best_lossG = 1000000.0
    accumulated_time = 0
    gen_iterations = 0
    n_critic = g_conf.N_CRITIC

    lossF = Variable(torch.Tensor([100.0]))
    lossG_adv = Variable(torch.Tensor([100.0]))
    lossG_smooth = Variable(torch.Tensor([100.0]))
    lossG = Variable(torch.Tensor([100.0]))

    netD_bin.train()
    netD_img.train()
    netG.train()
    netF.train()
    capture_time = time.time()

    if not os.path.exists('./imgs_' + exp_alias):
        os.mkdir('./imgs_' + exp_alias)

    # TODO: check how the C network is optimized in LSD-seg
    # TODO: put family for losses
    # IMPORTANT: while running this, conv.py must have batch norms

    fake_img_pool_src = ImagePool(50)
    fake_img_pool_tgt = ImagePool(50)

    for data in data_loader:

        set_requires_grad(netD_bin, True)
        set_requires_grad(netD_img, True)
        set_requires_grad(netG, True)
        set_requires_grad(netF, True)

        # print("ITERATION:", iteration)

        val = 0.0
        input_data, float_data, tgt_imgs = data

        if g_conf.IF_AUG:
            inputs = augmenter(0, input_data['rgb'])
            tgt_imgs = augmenter(0, tgt_imgs)
        else:
            inputs = input_data['rgb'].cuda()
            tgt_imgs = tgt_imgs.cuda()

        inputs = inputs.squeeze(1)
        inputs = inputs - val  # re-center inputs (val was 0.5 in earlier versions)
        tgt_imgs = tgt_imgs - val

        controls = float_data[:, dataset.controls_position(), :]

        src_embed_inputs, src_branches = netF(
            inputs, dataset.extract_inputs(float_data).cuda())
        tgt_embed_inputs = netF(tgt_imgs, None)

        src_img_fake = netG(src_embed_inputs)
        tgt_img_fake = netG(tgt_embed_inputs)

        if iteration % 100 == 0:
            imgs_to_save = torch.cat(
                (inputs[:1] + val, src_img_fake[:1] + val,
                 tgt_imgs[:1] + val, tgt_img_fake[:1] + val), 0).cpu().data
            coil_logger.add_image("Images", imgs_to_save, iteration)
            imgs_to_save = imgs_to_save.clamp(0.0, 1.0)
            vutils.save_image(imgs_to_save,
                              './imgs_' + exp_alias + '/' + str(iteration) +
                              '_real_and_fake.png',
                              normalize=False)

        ## -------------------- Discriminator (embedding) part -------------------- ##
        set_requires_grad(netD_bin, True)
        set_requires_grad(netD_img, False)
        set_requires_grad(netG, False)
        set_requires_grad(netF, False)

        optimD_bin.zero_grad()

        outputsD_real_src_bin = netD_bin(src_embed_inputs)
        outputsD_real_tgt_bin = netD_bin(tgt_embed_inputs)

        gradient_penalty = calc_gradient_penalty(netD_bin, src_embed_inputs,
                                                 tgt_embed_inputs)
        lossD_bin = torch.mean(
            outputsD_real_tgt_bin - outputsD_real_src_bin) + gradient_penalty
        lossD_bin.backward(retain_graph=True)
        optimD_bin.step()

        coil_logger.add_scalar('Total LossD Bin', lossD_bin.data, iteration)

        #### Discriminator (image) update ####
        set_requires_grad(netD_bin, False)
        set_requires_grad(netD_img, True)
        set_requires_grad(netG, False)
        set_requires_grad(netF, False)

        optimD_img.zero_grad()

        outputsD_fake_src_img = netD_img(src_img_fake.detach())
        outputsD_fake_tgt_img = netD_img(tgt_img_fake.detach())
        outputsD_real_src_img = netD_img(inputs)
        outputsD_real_tgt_img = netD_img(tgt_imgs)

        gradient_penalty_src = calc_gradient_penalty(netD_img, inputs,
                                                     src_img_fake)
        lossD_img_src = torch.mean(
            outputsD_fake_src_img - outputsD_real_src_img) + gradient_penalty_src

        gradient_penalty_tgt = calc_gradient_penalty(netD_img, tgt_imgs,
                                                     tgt_img_fake)
        lossD_img_tgt = torch.mean(
            outputsD_fake_tgt_img - outputsD_real_tgt_img) + gradient_penalty_tgt

        lossD_img = (lossD_img_src + lossD_img_tgt) * 0.5
        lossD_img.backward(retain_graph=True)
        optimD_img.step()

        coil_logger.add_scalar('Total LossD img', lossD_img.data, iteration)

        if ((iteration + 1) % n_critic) == 0:
            ##### Generator updates #####
            set_requires_grad(netD_bin, False)
            set_requires_grad(netD_img, False)
            set_requires_grad(netG, True)
            set_requires_grad(netF, False)

            optimG.zero_grad()  # fixed: zeroing was missing, so generator gradients accumulated across updates

            outputsD_fake_src_img = netD_img(src_img_fake)
            outputsD_real_tgt_img = netD_img(tgt_imgs)
            outputsD_fake_tgt_img = netD_img(tgt_img_fake)

            lossG_src_smooth = L1_loss(src_img_fake, inputs)
            lossG_tgt_smooth = L1_loss(tgt_img_fake, tgt_imgs)
            lossG_smooth = (lossG_src_smooth + lossG_tgt_smooth) * 0.5

            lossG_adv = 0.5 * (-1.0 * outputsD_fake_src_img.mean()
                               - 1.0 * outputsD_fake_tgt_img.mean())
            lossG = (lossG_smooth + 0.0 * lossG_adv)  # adversarial term weighted 0.0, i.e. disabled
            lossG.backward(retain_graph=True)
            optimG.step()

            coil_logger.add_scalar('Total LossG', lossG.data, iteration)

            ##### Task network updates #####
            set_requires_grad(netD_bin, False)
            set_requires_grad(netD_img, False)
            set_requires_grad(netG, False)
            set_requires_grad(netF, True)

            optimF.zero_grad()

            src_embed_inputs, src_branches = netF(
                inputs, dataset.extract_inputs(float_data).cuda())
            tgt_embed_inputs = netF(tgt_imgs, None)
            src_img_fake = netG(src_embed_inputs)
            tgt_img_fake = netG(tgt_embed_inputs)

            outputsD_fake_src_img = netD_img(src_img_fake)
            outputsD_real_tgt_img = netD_img(tgt_imgs)

            lossF_task = Task_Loss.MSELoss(
                src_branches,
                dataset.extract_targets(float_data).cuda(),
                controls.cuda(),
                dataset.extract_inputs(float_data).cuda())

            lossF_adv_bin = netD_bin(src_embed_inputs).mean() - netD_bin(
                tgt_embed_inputs).mean()
            lossF_adv_img = outputsD_fake_src_img.mean() - outputsD_real_tgt_img.mean()
            lossF_adv = 0.5 * (lossF_adv_bin + 0.1 * lossF_adv_img)
            lossF = (lossF_task + task_adv_weight * lossF_adv)

            coil_logger.add_scalar('Total Task Loss', lossF.data, iteration)
            coil_logger.add_scalar('Adv Task Loss', lossF_adv.data, iteration)
            coil_logger.add_scalar('Only Task Loss', lossF_task.data, iteration)

            lossF.backward(retain_graph=True)
            optimF.step()

            if lossF.data < best_lossF:
                best_lossF = lossF.data.tolist()
                best_loss_iter_F = iteration

        # optimization for one iter done!
        position = random.randint(0, len(float_data) - 1)

        accumulated_time += time.time() - capture_time
        capture_time = time.time()

        if is_ready_to_save(iteration):
            state = {
                'iteration': iteration,
                'stateD_bin_dict': netD_bin.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'total_time': accumulated_time,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'checkpoints', str(iteration) + '.pth'))

        if iteration == best_loss_iter_F and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_bin_dict': netD_bin.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossF': best_lossF,
                'total_time': accumulated_time,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelF' + '.pth'))

        iteration += 1
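# The training loops above freeze and unfreeze sub-networks with
# `set_requires_grad`, which is not defined in this file. A minimal sketch,
# assuming the usual CycleGAN-style utility (the project may define it in a
# shared helpers module):
def set_requires_grad(nets, requires_grad=False):
    """Enable or disable gradients for all parameters of one or more networks."""
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad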
def execute(gpu, exp_batch, exp_alias):
    from time import gmtime, strftime

    manualSeed = g_conf.SEED
    torch.cuda.manual_seed(manualSeed)

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    set_type_of_process('train')

    coil_logger.add_message('Loading', {'GPU': gpu})

    if not os.path.exists('_output_logs'):
        os.mkdir('_output_logs')
    sys.stdout = open(os.path.join(
        '_output_logs',
        g_conf.PROCESS_NAME + '_' + str(os.getpid()) + ".out"),
        "a", buffering=1)

    if monitorer.get_status(exp_batch, exp_alias + '.yaml',
                            g_conf.PROCESS_NAME)[0] == "Finished":
        return

    full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                g_conf.TRAIN_DATASET_NAME)
    real_dataset = g_conf.TARGET_DOMAIN_PATH
    # real_dataset = os.path.join(os.environ["COIL_DATASET_PATH"], "FinalRealWorldDataset")

    # main data loader
    dataset = CoILDataset(full_dataset,
                          transform=transforms.Compose([transforms.ToTensor()]))

    sampler = BatchSequenceSampler(
        splitter.control_steer_split(dataset.measurements, dataset.meta_data),
        g_conf.BATCH_SIZE, g_conf.NUMBER_IMAGES_SEQUENCE,
        g_conf.SEQUENCE_STRIDE)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=sampler,
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=True)

    # real (target-domain) image loader
    real_dl = real_dataloader.RealDataset(real_dataset, g_conf.BATCH_SIZE)

    st = lambda aug: iag.Sometimes(aug, 0.4)
    oc = lambda aug: iag.Sometimes(aug, 0.3)
    rl = lambda aug: iag.Sometimes(aug, 0.09)
    augmenter = iag.Augmenter([iag.ToGPU()] + [
        rl(iag.GaussianBlur((0, 1.5))),  # blur images with a sigma between 0 and 1.5
        rl(iag.AdditiveGaussianNoise(
            loc=0, scale=(0.0, 0.05), per_channel=0.5)),  # add gaussian noise to images
        oc(iag.Dropout((0.0, 0.10), per_channel=0.5)),  # randomly remove up to 10% of the pixels
        oc(iag.CoarseDropout(
            (0.0, 0.10), size_percent=(0.08, 0.2),
            per_channel=0.5)),  # randomly remove up to 10% of the pixels in patches
        oc(iag.Add((-40, 40), per_channel=0.5)),  # shift brightness by -40 to +40 of the original value
        st(iag.Multiply((0.10, 2), per_channel=0.2)),  # scale brightness to 10-200% of the original value
        rl(iag.ContrastNormalization((0.5, 1.5), per_channel=0.5)),  # improve or worsen the contrast
        rl(iag.Grayscale((0.0, 1))),  # put grayscale
    ])  # do all of the above in random order

    l1weight = g_conf.L1_WEIGHT
    task_adv_weight = g_conf.TASK_ADV_WEIGHT
    image_size = tuple([88, 200])

    print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))
    print("Configurations of ", exp_alias)
    print("GANMODEL_NAME", g_conf.GANMODEL_NAME)
    print("LOSS_FUNCTION", g_conf.LOSS_FUNCTION)
    print("LR_G, LR_D, LR", g_conf.LR_G, g_conf.LR_D, g_conf.LEARNING_RATE)
    print("SKIP", g_conf.SKIP)
    print("TYPE", g_conf.TYPE)
    print("L1 WEIGHT", g_conf.L1_WEIGHT)
    print("TASK ADV WEIGHT", g_conf.TASK_ADV_WEIGHT)
    print("LAB SMOOTH", g_conf.LABSMOOTH)

    if g_conf.GANMODEL_NAME == 'LSDcontrol':
        netD = ganmodels._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels._netG(loss=g_conf.LOSS_FUNCTION,
                               skip=g_conf.SKIP).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch':
        netD = ganmodels_nopatch._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch_smaller':
        netD = ganmodels_nopatch_smaller._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch_smaller._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task':
        netD = ganmodels_task._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_task._netG(loss=g_conf.LOSS_FUNCTION).cuda()
        netF = ganmodels_task._netF(loss=g_conf.LOSS_FUNCTION).cuda()

        if g_conf.PRETRAINED == 'RECON':
            netF_statedict = torch.load('netF_GAN_Pretrained.wts')
            netF.load_state_dict(netF_statedict)
        elif g_conf.PRETRAINED == 'IL':
            print("Loading IL")
            model_IL = torch.load('best_loss_20-06_EpicClearWeather.pth')
            model_IL_state_dict = model_IL['state_dict']

            netF_state_dict = netF.state_dict()
            print(len(netF_state_dict.keys()), len(model_IL_state_dict.keys()))
            for i, keys in enumerate(
                    zip(netF_state_dict.keys(), model_IL_state_dict.keys())):
                newkey, oldkey = keys
                # if newkey.split('.')[0] == "branch" and oldkey.split('.')[0] == "branches":
                #     print("No Transfer of ", newkey, " to ", oldkey)
                # else:
                print("Transferring ", newkey, " to ", oldkey)
                netF_state_dict[newkey] = model_IL_state_dict[oldkey]
            netF.load_state_dict(netF_state_dict)
            print("IL Model Loaded!")

    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task_2d':
        netD = ganmodels_taskAC_shared._netD().cuda()
        netG = ganmodels_taskAC_shared._netG().cuda()
        netF = ganmodels_taskAC_shared._netF().cuda()

        if g_conf.PRETRAINED == 'IL':
            print("Loading IL")
            model_IL = torch.load(g_conf.IL_AGENT_PATH)
            model_IL_state_dict = model_IL['state_dict']

            netF_state_dict = netF.state_dict()
            print(len(netF_state_dict.keys()), len(model_IL_state_dict.keys()))
            for i, keys in enumerate(
                    zip(netF_state_dict.keys(), model_IL_state_dict.keys())):
                newkey, oldkey = keys
                print("Transferring ", newkey, " to ", oldkey)
                netF_state_dict[newkey] = model_IL_state_dict[oldkey]
            netF.load_state_dict(netF_state_dict)
            print("IL Model Loaded!")

        #####
        if g_conf.IF_AUG:
            print("Loading Aug Decoder")
            model_dec = torch.load(g_conf.DECODER_RECON_PATH)
        else:
            print("Loading Decoder")
            model_dec = torch.load(g_conf.DECODER_RECON_PATH)
        model_dec_state_dict = model_dec['stateG_dict']

        netG_state_dict = netG.state_dict()
        print(len(netG_state_dict.keys()), len(model_dec_state_dict.keys()))
        for i, keys in enumerate(
                zip(netG_state_dict.keys(), model_dec_state_dict.keys())):
            newkey, oldkey = keys
            print("Transferring ", newkey, " to ", oldkey)
            netG_state_dict[newkey] = model_dec_state_dict[oldkey]
        netG.load_state_dict(netG_state_dict)
        print("Decoder Model Loaded!")

    init_weights(netD)

    print(netD)
    print(netF)
    print(netG)

    optimD = torch.optim.Adam(netD.parameters(), lr=g_conf.LR_D,
                              betas=(0.5, 0.999))
    optimG = torch.optim.Adam(netG.parameters(), lr=g_conf.LR_G,
                              betas=(0.5, 0.999))
    if g_conf.TYPE == 'task':
        optimF = torch.optim.Adam(netF.parameters(), lr=g_conf.LEARNING_RATE)
        Task_Loss = TaskLoss()

    if g_conf.GANMODEL_NAME == 'LSDcontrol_task_2d':
        print("Using cross entropy!")
        Loss = torch.nn.CrossEntropyLoss().cuda()
    L1_loss = torch.nn.L1Loss().cuda()

    iteration = 0
    best_loss_iter_F = 0
    best_loss_iter_G = 0
    best_lossF = 1000000.0
    best_lossD = 1000000.0
    best_lossG = 1000000.0
    accumulated_time = 0
    n_critic = g_conf.N_CRITIC

    lossF = Variable(torch.Tensor([100.0]))
    lossG_adv = Variable(torch.Tensor([100.0]))
    lossG_smooth = Variable(torch.Tensor([100.0]))
    lossG = Variable(torch.Tensor([100.0]))

    netG.train()
    netD.train()
    netF.train()
    capture_time = time.time()

    if not os.path.exists('./imgs_' + exp_alias):
        os.mkdir('./imgs_' + exp_alias)

    fake_img_pool_src = ImagePool(50)
    fake_img_pool_tgt = ImagePool(50)

    for data in data_loader:

        set_requires_grad(netD, True)
        set_requires_grad(netF, True)
        set_requires_grad(netG, True)

        input_data, float_data = data
        tgt_imgs = real_dl.get_imgs()

        if g_conf.IF_AUG:
            inputs = augmenter(0, input_data['rgb'])
        else:
            inputs = input_data['rgb'].cuda()
        tgt_imgs = tgt_imgs.cuda()

        inputs = inputs.squeeze(1)

        controls = float_data[:, dataset.controls_position(), :]

        # Steering-label correction for laterally mounted cameras: shift the
        # recorded steer towards the lane center as a function of camera
        # angle and speed, then clamp to [-1, 1].
        camera_angle = float_data[:, 26, :]
        camera_angle = camera_angle.cuda()
        steer = float_data[:, 0, :]
        steer = steer.cuda()
        speed = float_data[:, 10, :]
        speed = speed.cuda()

        time_use = 1.0
        car_length = 3.0
        extra_factor = 2.5
        threshold = 1.0

        pos = camera_angle > 0.0
        pos = pos.type(torch.FloatTensor)
        neg = camera_angle <= 0.0
        neg = neg.type(torch.FloatTensor)
        pos = pos.cuda()
        neg = neg.cuda()

        rad_camera_angle = math.pi * (torch.abs(camera_angle)) / 180.0
        val = extra_factor * (torch.atan(
            (rad_camera_angle * car_length) / (time_use * speed + 0.05))) / 3.1415
        steer -= pos * torch.min(val, torch.Tensor([0.6]).cuda())
        steer += neg * torch.min(val, torch.Tensor([0.6]).cuda())

        steer = steer.cpu()
        float_data[:, 0, :] = steer
        float_data[:, 0, :][float_data[:, 0, :] > 1.0] = 1.0
        float_data[:, 0, :][float_data[:, 0, :] < -1.0] = -1.0

        src_embed_inputs, src_branches = netF(
            inputs, dataset.extract_inputs(float_data).cuda())
        tgt_embed_inputs = netF(tgt_imgs, None)

        src_fake_inputs = netG(src_embed_inputs.detach())
        tgt_fake_inputs = netG(tgt_embed_inputs.detach())

        if iteration % 100 == 0:
            imgs_to_save = torch.cat(
                (inputs[:1], src_fake_inputs[:1], tgt_imgs[:1],
                 tgt_fake_inputs[:1]), 0).cpu().data
            coil_logger.add_image("Images", imgs_to_save, iteration)
            imgs_to_save = imgs_to_save.clamp(0.0, 1.0)
            vutils.save_image(imgs_to_save,
                              './imgs_' + exp_alias + '/' + str(iteration) +
                              '_real_and_fake.png',
                              normalize=False)

        ## -------------------- Discriminator part -------------------- ##
        # source/target fakes, optionally drawn from a history pool
        if g_conf.IF_POOL:
            src_fake_inputs_forD = fake_img_pool_src.query(src_fake_inputs)
            tgt_fake_inputs_forD = fake_img_pool_tgt.query(tgt_fake_inputs)
        else:
            src_fake_inputs_forD = src_fake_inputs
            tgt_fake_inputs_forD = tgt_fake_inputs

        set_requires_grad(netD, True)
        set_requires_grad(netF, False)
        set_requires_grad(netG, False)

        optimD.zero_grad()

        outputsD_fake_src_bin, __ = netD(src_fake_inputs_forD.detach())
        outputsD_fake_tgt_bin, __ = netD(tgt_fake_inputs_forD.detach())
        outputsD_real_src_bin, __ = netD(inputs)
        outputsD_real_tgt_bin, __ = netD(tgt_imgs)

        gradient_penalty_src = calc_gradient_penalty(
            netD, inputs, src_fake_inputs_forD, "recon")
        lossD_bin_src = torch.mean(
            outputsD_fake_src_bin - outputsD_real_src_bin) + gradient_penalty_src

        gradient_penalty_tgt = calc_gradient_penalty(
            netD, tgt_imgs, tgt_fake_inputs_forD, "recon")
        lossD_bin_tgt = torch.mean(
            outputsD_fake_tgt_bin - outputsD_real_tgt_bin) + gradient_penalty_tgt

        lossD = (lossD_bin_src + lossD_bin_tgt) * 0.5
        lossD.backward(retain_graph=True)
        optimD.step()

        coil_logger.add_scalar('Total LossD Bin', lossD.data, iteration)
        coil_logger.add_scalar('Src LossD Bin', lossD_bin_src.data, iteration)
        coil_logger.add_scalar('Tgt LossD Bin', lossD_bin_tgt.data, iteration)

        ## -------------------- Generator part -------------------- ##
        set_requires_grad(netD, False)
        set_requires_grad(netF, False)
        set_requires_grad(netG, True)

        optimG.zero_grad()

        # fake outputs for the real/fake ("bin") head
        outputsD_bin_src_fake_forG, __ = netD(src_fake_inputs)
        outputsD_bin_tgt_fake_forG, __ = netD(tgt_fake_inputs)

        # Generator updates
        if ((iteration + 1) % n_critic) == 0:
            # for netD_bin
            optimG.zero_grad()
            outputsD_bin_fake_forG = netD(tgt_imgs)  # (unused)

            # Generator updates
            lossG_src_smooth = L1_loss(
                src_fake_inputs, inputs)  # L1 loss with real domain image
            lossG_tgt_smooth = L1_loss(
                tgt_fake_inputs, tgt_imgs)  # L1 loss with real domain image

            lossG_src_adv_bin = -1.0 * torch.mean(outputsD_bin_src_fake_forG)
            lossG_tgt_adv_bin = -1.0 * torch.mean(outputsD_bin_tgt_fake_forG)
            lossG_adv_bin = 0.5 * (lossG_src_adv_bin + lossG_tgt_adv_bin)

            lossG_Adv = lossG_adv_bin
            lossG_L1 = 0.5 * (lossG_src_smooth + lossG_tgt_smooth)

            lossG = (lossG_Adv + l1weight * lossG_L1) / (1.0 + l1weight)
            lossG.backward(retain_graph=True)
            optimG.step()

            coil_logger.add_scalar('Total LossG', lossG.data, iteration)
            coil_logger.add_scalar('LossG Adv', lossG_Adv.data, iteration)
            coil_logger.add_scalar('Adv Bin LossG', lossG_adv_bin.data, iteration)
            coil_logger.add_scalar('Smooth LossG', lossG_L1.data, iteration)

            ##### Task network updates #####
            set_requires_grad(netD, False)
            set_requires_grad(netF, True)
            set_requires_grad(netG, False)

            optimF.zero_grad()

            lossF_task = Task_Loss.MSELoss(
                src_branches,
                dataset.extract_targets(float_data).cuda(),
                controls.cuda(),
                dataset.extract_inputs(float_data).cuda())

            # domain-adaptation ("da") head: align source fakes with target
            # reals, and target fakes with source reals
            __, outputsD_fake_src_da = netD(src_fake_inputs_forD.detach())
            __, outputsD_real_tgt_da = netD(tgt_imgs)
            __, outputsD_fake_tgt_da = netD(tgt_fake_inputs_forD.detach())
            __, outputsD_real_src_da = netD(inputs)

            gradient_penalty_da_1 = calc_gradient_penalty(
                netD, tgt_imgs, src_fake_inputs_forD, "da")
            lossF_da_1 = torch.mean(
                outputsD_fake_src_da - outputsD_real_tgt_da) + gradient_penalty_da_1

            gradient_penalty_da_2 = calc_gradient_penalty(
                netD, inputs, tgt_fake_inputs_forD, "da")
            lossF_da_2 = torch.mean(
                outputsD_fake_tgt_da - outputsD_real_src_da) + gradient_penalty_da_2

            lossF_da = 0.5 * (lossF_da_1 + lossF_da_2)
            lossF = (lossF_task + task_adv_weight * lossF_da) / (1.0 + task_adv_weight)

            coil_logger.add_scalar('Total Task Loss', lossF.data, iteration)
            coil_logger.add_scalar('Adv Task Loss', lossF_da.data, iteration)
            coil_logger.add_scalar('Only Task Loss', lossF_task.data, iteration)

            lossF.backward(retain_graph=True)
            optimF.step()

            if lossG.data < best_lossG:
                best_lossG = lossG.data.tolist()
                best_loss_iter_G = iteration

            if lossF.data < best_lossF:
                best_lossF = lossF.data.tolist()
                best_loss_iter_F = iteration

        # optimization for one iter done!
        position = random.randint(0, len(float_data) - 1)

        if lossD.data < best_lossD:
            best_lossD = lossD.data.tolist()

        accumulated_time += time.time() - capture_time
        capture_time = time.time()

        if is_ready_to_save(iteration):
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter_G': best_loss_iter_G,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Datasets/rohitgan/_logs', exp_batch,
                             exp_alias, 'checkpoints', str(iteration) + '.pth'))

        if iteration == best_loss_iter_F and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'best_lossF': best_lossF,
                'total_time': accumulated_time,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Datasets/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelF' + '.pth'))

        iteration += 1
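# `calc_gradient_penalty` is not defined in these scripts; the critic losses
# above follow WGAN-GP (Gulrajani et al., 2017), and the commented-out block
# in the WGAN script below uses the standard penalty with weight 10. A sketch
# under those assumptions; the handling of the optional mode string ("recon"
# selects the real/fake head, anything else the "da" head when netD returns a
# pair) is a guess based on how the shared discriminator is called above:
import torch
from torch import autograd

def calc_gradient_penalty(netD, real_data, fake_data, mode=None, lambda_gp=10.0):
    """WGAN-GP penalty on random interpolates between real and fake batches."""
    batch_size = real_data.size(0)
    # one random mixing coefficient per sample, broadcast over the remaining dims
    alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)
    interpolates = alpha * real_data.detach() + (1.0 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)
    out = netD(interpolates)
    if isinstance(out, tuple):  # shared D with (bin, da) heads -- assumption
        out = out[0] if mode == "recon" else out[1]
    grads = autograd.grad(outputs=out, inputs=interpolates,
                          grad_outputs=torch.ones_like(out),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    grads = grads.view(batch_size, -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1.0) ** 2).mean()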
def execute(gpu, exp_batch, exp_alias):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    set_type_of_process('train')

    coil_logger.add_message('Loading', {'GPU': gpu})

    if not os.path.exists('_output_logs'):
        os.mkdir('_output_logs')
    sys.stdout = open(os.path.join(
        '_output_logs',
        g_conf.PROCESS_NAME + '_' + str(os.getpid()) + ".out"),
        "a", buffering=1)

    if monitorer.get_status(exp_batch, exp_alias + '.yaml',
                            g_conf.PROCESS_NAME)[0] == "Finished":
        return

    full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                g_conf.TRAIN_DATASET_NAME)
    dataset = CoILDataset(full_dataset,
                          transform=transforms.Compose([transforms.ToTensor()]))

    sampler = BatchSequenceSampler(
        splitter.control_steer_split(dataset.measurements, dataset.meta_data),
        g_conf.BATCH_SIZE, g_conf.NUMBER_IMAGES_SEQUENCE,
        g_conf.SEQUENCE_STRIDE)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=sampler,
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=True)

    l1weight = g_conf.L1_WEIGHT
    image_size = tuple([88, 200])

    if g_conf.TRAIN_TYPE == 'WGAN':
        clamp_value = g_conf.CLAMP
        n_critic = g_conf.N_CRITIC

    print("Configurations of ", exp_alias)
    print("GANMODEL_NAME", g_conf.GANMODEL_NAME)
    print("LOSS_FUNCTION", g_conf.LOSS_FUNCTION)
    print("LR_G, LR_D, LR", g_conf.LR_G, g_conf.LR_D, g_conf.LEARNING_RATE)
    print("SKIP", g_conf.SKIP)
    print("TYPE", g_conf.TYPE)
    print("L1 WEIGHT", g_conf.L1_WEIGHT)
    print("LAB SMOOTH", g_conf.LABSMOOTH)

    if g_conf.GANMODEL_NAME == 'LSDcontrol':
        netD = ganmodels._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels._netG(loss=g_conf.LOSS_FUNCTION,
                               skip=g_conf.SKIP).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch':
        netD = ganmodels_nopatch._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch_smaller':
        netD = ganmodels_nopatch_smaller._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch_smaller._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task':
        netD = ganmodels_task._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_task._netG(loss=g_conf.LOSS_FUNCTION).cuda()
        netF = ganmodels_task._netF(loss=g_conf.LOSS_FUNCTION).cuda()

        if g_conf.PRETRAINED == 'RECON':
            netF_statedict = torch.load('netF_GAN_Pretrained.wts')
            netF.load_state_dict(netF_statedict)
        elif g_conf.PRETRAINED == 'IL':
            model_IL = torch.load('best_loss_20-06_EpicClearWeather.pth')
            model_IL_state_dict = model_IL['state_dict']

            netF_state_dict = netF.state_dict()
            for i, keys in enumerate(
                    zip(netF_state_dict.keys(), model_IL_state_dict.keys())):
                newkey, oldkey = keys
                if newkey.split('.')[0] == "branch" and oldkey.split(
                        '.')[0] == "branches":
                    print("No Transfer of ", newkey, " to ", oldkey)
                else:
                    print("Transferring ", newkey, " to ", oldkey)
                    netF_state_dict[newkey] = model_IL_state_dict[oldkey]
            netF.load_state_dict(netF_state_dict)

    init_weights(netD)
    init_weights(netG)
    # do init for netF also later, but for now it is in the model code itself

    print(netD)
    print(netF)
    print(netG)

    optimD = torch.optim.Adam(netD.parameters(), lr=g_conf.LR_D,
                              betas=(0.5, 0.999))
    optimG = torch.optim.Adam(netG.parameters(), lr=g_conf.LR_G,
                              betas=(0.5, 0.999))
    if g_conf.TYPE == 'task':
        optimF = torch.optim.Adam(netF.parameters(), lr=g_conf.LEARNING_RATE)
        Task_Loss = TaskLoss()

    if g_conf.LOSS_FUNCTION == 'LSGAN':
        Loss = torch.nn.MSELoss().cuda()
    elif g_conf.LOSS_FUNCTION == 'NORMAL':
        Loss = torch.nn.BCEWithLogitsLoss().cuda()
    L1_loss = torch.nn.L1Loss().cuda()

    iteration = 0
    best_loss_iter_F = 0
    best_loss_iter_G = 0
    best_lossF = 1000000.0
    best_lossD = 1000000.0
    best_lossG = 1000000.0
    accumulated_time = 0

    lossF = Variable(torch.Tensor([100.0]))
    lossG_adv = Variable(torch.Tensor([100.0]))
    lossG_smooth = Variable(torch.Tensor([100.0]))
    lossG = Variable(torch.Tensor([100.0]))

    netG.train()
    netD.train()
    netF.train()
    capture_time = time.time()

    if not os.path.exists('./imgs_' + exp_alias):
        os.mkdir('./imgs_' + exp_alias)

    # TODO: put family for losses

    fake_img_pool = ImagePool(50)

    for data in data_loader:

        set_requires_grad(netD, True)
        set_requires_grad(netF, True)
        set_requires_grad(netG, True)

        # print("ITERATION:", iteration)

        val = 0.0
        input_data, float_data = data
        inputs = input_data['rgb'].cuda()
        inputs = inputs.squeeze(1)
        inputs_in = inputs - val  # re-center inputs (val was 0.5 in earlier versions)

        # TODO: make sure the F network does not get optimized by the G optimizer
        controls = float_data[:, dataset.controls_position(), :]

        embed, branches = netF(inputs_in,
                               dataset.extract_inputs(float_data).cuda())
        print("Branch Outputs:::", branches[0][0])

        embed_inputs = embed
        fake_inputs = netG(embed_inputs)
        fake_inputs_in = fake_inputs

        if iteration % 500 == 0:
            imgs_to_save = torch.cat(
                (inputs_in[:2] + val, fake_inputs_in[:2] + val), 0).cpu().data
            vutils.save_image(imgs_to_save,
                              './imgs_' + exp_alias + '/' + str(iteration) +
                              '_real_and_fake.png',
                              normalize=True)
            coil_logger.add_image("Images", imgs_to_save, iteration)

        ## -------------------- Discriminator part -------------------- ##
        set_requires_grad(netD, True)
        set_requires_grad(netF, False)
        set_requires_grad(netG, False)

        optimD.zero_grad()

        ## fake
        # fake_inputs_forD = fake_img_pool.query(fake_inputs)
        outputsD_fake_forD = netD(fake_inputs)
        labsize = outputsD_fake_forD.size()
        labels_fake = torch.zeros(labsize)  # fake labels
        label_fake_noise = torch.rand(labels_fake.size()) * 0.1  # label smoothing
        if g_conf.LABSMOOTH == 1:
            labels_fake = labels_fake + label_fake_noise
        labels_fake = Variable(labels_fake).cuda()
        lossD_fake = torch.mean(outputsD_fake_forD)  # Loss(outputsD_fake_forD, labels_fake)

        ## real
        outputsD_real = netD(inputs_in)
        labsize = outputsD_real.size()
        labels_real = torch.ones(labsize)  # real labels
        label_real_noise = torch.rand(labels_real.size()) * 0.1  # label smoothing
        if g_conf.LABSMOOTH == 1:
            labels_real = labels_real - label_real_noise
        labels_real = Variable(labels_real).cuda()
        lossD_real = -1.0 * torch.mean(outputsD_real)  # Loss(outputsD_real, labels_real)

        ### Gradient Penalty ###
        gradient_penalty = calc_gradient_penalty(netD, inputs, fake_inputs)
        # alpha = torch.rand((g_conf.BATCH_SIZE, 1, 1, 1))
        # alpha = alpha.cuda()
        #
        # x_hat = alpha * inputs.data + (1 - alpha) * fake_inputs.data
        # x_hat.requires_grad = True
        #
        # pred_hat = netD(x_hat)
        # gradients = grad(outputs=pred_hat, inputs=x_hat,
        #                  grad_outputs=torch.ones(pred_hat.size()).cuda(),
        #                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        #
        # gradient_penalty = 10 * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

        # Discriminator update (WGAN-GP critic loss)
        lossD = torch.mean(
            outputsD_fake_forD - outputsD_real) + gradient_penalty  # (lossD_real + lossD_fake) * 0.5
        # lossD /= len(inputs)
        print("Loss d", lossD)
        lossD.backward(retain_graph=True)
        optimD.step()

        # if g_conf.TRAIN_TYPE == 'WGAN':
        #     for p in netD.parameters():
        #         p.data.clamp_(-clamp_value, clamp_value)

        coil_logger.add_scalar('Total LossD', lossD.data, iteration)
        coil_logger.add_scalar('Real LossD', lossD_real.data / len(inputs), iteration)
        coil_logger.add_scalar('Fake LossD', lossD_fake.data / len(inputs), iteration)

        ## -------------------- Generator part -------------------- ##
        set_requires_grad(netD, False)
        set_requires_grad(netF, False)
        set_requires_grad(netG, True)

        if ((iteration + 1) % n_critic) == 0:
            optimG.zero_grad()
            outputsD_fake_forG = netD(fake_inputs)

            # Generator updates
            lossG_adv = -1.0 * torch.mean(outputsD_fake_forG)  # Loss(outputsD_fake_forG, labels_real)
            lossG_smooth = L1_loss(fake_inputs, inputs)
            lossG = (lossG_adv + l1weight * lossG_smooth) / (1.0 + l1weight)
            # lossG /= len(inputs)
            print(lossG)
            lossG.backward(retain_graph=True)
            optimG.step()

            ##### Task network updates #####
            set_requires_grad(netD, False)
            set_requires_grad(netF, True)
            set_requires_grad(netG, False)

            optimF.zero_grad()
            lossF = Task_Loss.MSELoss(
                branches,
                dataset.extract_targets(float_data).cuda(),
                controls.cuda(),
                dataset.extract_inputs(float_data).cuda())
            coil_logger.add_scalar('Task Loss', lossF.data, iteration)
            lossF.backward()
            optimF.step()

            coil_logger.add_scalar('Total LossG', lossG.data, iteration)
            coil_logger.add_scalar('Adv LossG', lossG_adv.data / len(inputs), iteration)
            coil_logger.add_scalar('Smooth LossG', lossG_smooth.data / len(inputs), iteration)

        # optimization for one iter done!
        position = random.randint(0, len(float_data) - 1)

        if lossD.data < best_lossD:
            best_lossD = lossD.data.tolist()

        # print(lossG.item(), best_lossG)
        if lossG.item() < best_lossG:
            best_lossG = lossG.item()
            best_loss_iter_G = iteration

        if lossF.item() < best_lossF:
            best_lossF = lossF.item()
            best_loss_iter_F = iteration

        accumulated_time += time.time() - capture_time
        capture_time = time.time()

        print("LossD", lossD.data.tolist(), "LossG", lossG.data.tolist(),
              "BestLossD", best_lossD, "BestLossG", best_lossG,
              "LossF", lossF, "BestLossF", best_lossF,
              "Iteration", iteration,
              "Best Loss Iteration G", best_loss_iter_G,
              "Best Loss Iteration F", best_loss_iter_F)

        coil_logger.add_message(
            'Iterating', {
                'Iteration': iteration,
                'LossD': lossD.data.tolist(),
                'LossG': lossG.data.tolist(),
                'Images/s': (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                'BestLossD': best_lossD,
                'BestLossG': best_lossG,
                'BestLossIterationG': best_loss_iter_G,
                'BestLossF': best_lossF,
                'BestLossIterationF': best_loss_iter_F,
                'GroundTruth': dataset.extract_targets(float_data)[position].data.tolist(),
                'Inputs': dataset.extract_inputs(float_data)[position].data.tolist()
            }, iteration)

        if is_ready_to_save(iteration):
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter_G': best_loss_iter_G,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'checkpoints', str(iteration) + '.pth'))

        if iteration == best_loss_iter_G and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter_G': best_loss_iter_G
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelG' + '.pth'))

        if iteration == best_loss_iter_F and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'best_lossF': best_lossF,
                'total_time': accumulated_time,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelF' + '.pth'))

        iteration += 1
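# `ImagePool(50)` above is the fake-image history buffer from pix2pix/CycleGAN:
# the discriminator sees a mix of current and past generator outputs, which
# reduces oscillation. A minimal sketch assuming the standard Zhu et al.
# implementation (the project may use its own variant):
import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        """Return a batch mixing current fakes with previously stored ones."""
        if self.pool_size == 0:
            return images
        out = []
        for img in images.detach():
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)   # fill the pool first
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())  # return an old fake
                self.images[idx] = img                # store the new one
            else:
                out.append(img)           # return the current fake
        return torch.cat(out, 0)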
import glob
import os

import torch
from torchvision import transforms

from input.coil_dataset_onlyil import CoILDataset
import network.models.coil_ganmodules_taskAC as ganmodels_taskAC
from network.loss_task import TaskLoss

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

full_dataset = '/datatmp/Datasets/JulyRohitRishabh/EpicWeather12_60k_June21_Straight+Turn/SeqVal'
dataset = CoILDataset(full_dataset,
                      transform=transforms.Compose([transforms.ToTensor()]))
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=50,
                                          shuffle=False,
                                          num_workers=12,
                                          pin_memory=True)

ckpts = glob.glob('/datatmp/Experiments/rohitgan/_logs/eccv/all_da_aug_orig_E1E12/checkpoints/*.pth')

netF = ganmodels_taskAC._netF().cuda()
Task_Loss = TaskLoss()

best_loss = 1000
best_loss_ckpt = "none"

for ckpt in ckpts:
    iter = 0
    current_loss = 0
    current_loss_total = 0
    print(ckpt)
    model_IL = torch.load(ckpt)
    model_IL_state_dict = model_IL['stateF_dict']
    netF.load_state_dict(model_IL_state_dict)
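    # The source truncates here. A plausible continuation, assuming the loop
    # averages the task loss over the validation loader and keeps the best
    # checkpoint (names follow the code above; an untested sketch, not the
    # original implementation):
    for data in data_loader:
        input_data, float_data = data
        inputs = input_data['rgb'].cuda().squeeze(1)
        controls = float_data[:, dataset.controls_position(), :]
        with torch.no_grad():
            embed, branches = netF(inputs,
                                   dataset.extract_inputs(float_data).cuda())
            loss = Task_Loss.MSELoss(
                branches,
                dataset.extract_targets(float_data).cuda(),
                controls.cuda(),
                dataset.extract_inputs(float_data).cuda())
        current_loss_total += loss.item()
        iter += 1
    current_loss = current_loss_total / max(iter, 1)
    print("mean validation loss:", current_loss)
    if current_loss < best_loss:
        best_loss = current_loss
        best_loss_ckpt = ckpt

print("Best checkpoint:", best_loss_ckpt, "with loss", best_loss)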
def execute(gpu, exp_batch, exp_alias):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    set_type_of_process('train')

    coil_logger.add_message('Loading', {'GPU': gpu})

    if not os.path.exists('_output_logs'):
        os.mkdir('_output_logs')
    sys.stdout = open(os.path.join(
        '_output_logs',
        g_conf.PROCESS_NAME + '_' + str(os.getpid()) + ".out"),
        "a", buffering=1)

    if monitorer.get_status(exp_batch, exp_alias + '.yaml',
                            g_conf.PROCESS_NAME)[0] == "Finished":
        return

    full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                g_conf.TRAIN_DATASET_NAME)
    dataset = CoILDataset(full_dataset,
                          transform=transforms.Compose([transforms.ToTensor()]))

    sampler = BatchSequenceSampler(
        splitter.control_steer_split(dataset.measurements, dataset.meta_data),
        g_conf.BATCH_SIZE, g_conf.NUMBER_IMAGES_SEQUENCE,
        g_conf.SEQUENCE_STRIDE)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=sampler,
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=True)

    l1weight = g_conf.L1_WEIGHT
    image_size = tuple([88, 200])

    print("Configurations of ", exp_alias)
    print("GANMODEL_NAME", g_conf.GANMODEL_NAME)
    print("LOSS_FUNCTION", g_conf.LOSS_FUNCTION)
    print("LR_G, LR_D, LR", g_conf.LR_G, g_conf.LR_D, g_conf.LEARNING_RATE)
    print("SKIP", g_conf.SKIP)
    print("TYPE", g_conf.TYPE)
    print("L1 WEIGHT", g_conf.L1_WEIGHT)
    print("LAB SMOOTH", g_conf.LABSMOOTH)

    if g_conf.GANMODEL_NAME == 'LSDcontrol':
        netD = ganmodels._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels._netG(loss=g_conf.LOSS_FUNCTION,
                               skip=g_conf.SKIP).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch':
        netD = ganmodels_nopatch._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_nopatch_smaller':
        netD = ganmodels_nopatch_smaller._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_nopatch_smaller._netG(loss=g_conf.LOSS_FUNCTION).cuda()
    elif g_conf.GANMODEL_NAME == 'LSDcontrol_task':
        netD = ganmodels_task._netD(loss=g_conf.LOSS_FUNCTION).cuda()
        netG = ganmodels_task._netG(loss=g_conf.LOSS_FUNCTION).cuda()
        netF = ganmodels_task._netF(loss=g_conf.LOSS_FUNCTION).cuda()  # fixed: was _netG, leaving netF as a second generator

    init_weights(netD)
    init_weights(netG)

    print(netD)
    print(netG)

    optimD = torch.optim.Adam(netD.parameters(), lr=g_conf.LR_D,
                              betas=(0.7, 0.999))
    optimG = torch.optim.Adam(netG.parameters(), lr=g_conf.LR_G,
                              betas=(0.7, 0.999))
    if g_conf.TYPE == 'task':
        optimF = torch.optim.Adam(netF.parameters(),  # fixed: previously optimized netG's parameters
                                  lr=g_conf.LEARNING_RATE,
                                  betas=(0.7, 0.999))
        Task_Loss = TaskLoss()

    if g_conf.LOSS_FUNCTION == 'LSGAN':
        Loss = torch.nn.MSELoss().cuda()
    elif g_conf.LOSS_FUNCTION == 'NORMAL':
        Loss = torch.nn.BCELoss().cuda()
    L1_loss = torch.nn.L1Loss().cuda()

    iteration = 0
    best_loss_iter = 0
    best_lossD = 1000000.0
    best_lossG = 1000000.0
    accumulated_time = 0

    netG.train()
    netD.train()
    capture_time = time.time()

    if not os.path.exists('./imgs_' + exp_alias):
        os.mkdir('./imgs_' + exp_alias)

    # TODO: add image queue
    # TODO: add auxiliary regression loss for steering
    # TODO: put family for losses

    fake_img_pool = ImagePool(50)

    for data in data_loader:

        val = 0.5
        input_data, float_data = data
        inputs = input_data['rgb'].cuda()
        inputs = inputs.squeeze(1)
        inputs_in = inputs - val  # shift inputs from [0, 1] to [-0.5, 0.5]

        fake_inputs = netG(inputs_in)
        fake_inputs_in = fake_inputs

        if iteration % 200 == 0:
            imgs_to_save = torch.cat(
                (inputs_in[:2] + val, fake_inputs_in[:2]), 0).cpu().data
            vutils.save_image(imgs_to_save,
                              './imgs_' + exp_alias + '/' + str(iteration) +
                              '_real_and_fake.png',
                              normalize=True)
            coil_logger.add_image("Images", imgs_to_save, iteration)

        ## -------------------- Discriminator part -------------------- ##
        set_requires_grad(netD, True)
        optimD.zero_grad()

        ## fake
        fake_inputs_forD = fake_img_pool.query(fake_inputs)
        outputsD_fake_forD = netD(fake_inputs_forD.detach())
        labsize = outputsD_fake_forD.size()
        labels_fake = torch.zeros(labsize)  # fake labels
        label_fake_noise = torch.rand(labels_fake.size()) * 0.05 - 0.025  # label smoothing
        if g_conf.LABSMOOTH == 1:
            labels_fake = labels_fake + label_fake_noise  # fixed: was the undefined name labels_fake_noise
        labels_fake = Variable(labels_fake).cuda()
        lossD_fake = Loss(outputsD_fake_forD, labels_fake)

        ## real
        outputsD_real = netD(inputs)
        print("some d outputs", outputsD_real[0])
        labsize = outputsD_real.size()
        labels_real = torch.ones(labsize)  # real labels
        label_real_noise = torch.rand(labels_real.size()) * 0.05 - 0.025  # label smoothing
        if g_conf.LABSMOOTH == 1:
            labels_real = labels_real + label_real_noise  # fixed: was the undefined name labels_real_noise
        labels_real = Variable(labels_real).cuda()
        lossD_real = Loss(outputsD_real, labels_real)

        # Discriminator update
        lossD = (lossD_real + lossD_fake) * 0.5
        # lossD /= len(inputs)
        lossD.backward()
        optimD.step()

        coil_logger.add_scalar('Total LossD', lossD.data, iteration)
        coil_logger.add_scalar('Real LossD', lossD_real.data, iteration)
        coil_logger.add_scalar('Fake LossD', lossD_fake.data, iteration)

        ## -------------------- Generator part -------------------- ##
        set_requires_grad(netD, False)
        optimG.zero_grad()
        outputsD_fake_forG = netD(fake_inputs)

        # Generator update
        lossG_adv = Loss(outputsD_fake_forG, labels_real)
        lossG_smooth = L1_loss(fake_inputs, inputs)
        lossG = (lossG_adv + l1weight * lossG_smooth) / (1.0 + l1weight)
        lossG.backward()  # removed a stray no-op `lossG` statement here
        optimG.step()

        coil_logger.add_scalar('Total LossG', lossG.data, iteration)
        coil_logger.add_scalar('Adv LossG', lossG_adv.data, iteration)
        coil_logger.add_scalar('Smooth LossG', lossG_smooth.data, iteration)

        # optimization for one iter done!
        position = random.randint(0, len(float_data) - 1)

        if lossD.data < best_lossD:
            best_lossD = lossD.data.tolist()

        if lossG.data < best_lossG:
            best_lossG = lossG.data.tolist()
            best_loss_iter = iteration

        accumulated_time += time.time() - capture_time
        capture_time = time.time()

        print("LossD", lossD.data.tolist(), "LossG", lossG.data.tolist(),
              "BestLossD", best_lossD, "BestLossG", best_lossG,
              "Iteration", iteration, "Best Loss Iteration", best_loss_iter)

        coil_logger.add_message(
            'Iterating', {
                'Iteration': iteration,
                'LossD': lossD.data.tolist(),
                'LossG': lossG.data.tolist(),
                'Images/s': (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                'BestLossD': best_lossD,
                'BestLossG': best_lossG,
                'BestLossIteration': best_loss_iter,  # fixed: the key appeared twice in the original dict
                'GroundTruth': dataset.extract_targets(float_data)[position].data.tolist(),
                'Inputs': dataset.extract_inputs(float_data)[position].data.tolist()
            }, iteration)

        if is_ready_to_save(iteration):
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter': best_loss_iter
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'checkpoints', str(iteration) + '.pth'))

        if iteration == best_loss_iter:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter': best_loss_iter
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelG' + '.pth'))

        iteration += 1
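# Checkpoints above are written whenever `is_ready_to_save(iteration)` fires,
# a helper these scripts import from the CoIL framework. A sketch of a
# plausible implementation, assuming a g_conf.SAVE_SCHEDULE list of iteration
# numbers (the framework's actual logic may differ):
def is_ready_to_save(iteration):
    """Return True when the current iteration is on the save schedule."""
    return iteration in set(g_conf.SAVE_SCHEDULE)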
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(UNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        lr_task = hyperparameters['lr_task']

        # task part
        self.netF = _netF().cuda()
        self.netF.train()  # fixed: was `netF.train()`, an undefined name
        self.Task_Loss = TaskLoss()

        # Initiate the networks
        self.gen_a = VAEGen(hyperparameters['input_dim_a'],
                            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = VAEGen(hyperparameters['input_dim_b'],
                            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'],
                                hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        task_params = list(self.netF.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.task_opt = torch.optim.Adam(
            [p for p in task_params if p.requires_grad],
            lr=lr_task, betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        h_a, _ = self.gen_a.encode(x_a)
        h_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(h_b)
        x_ab = self.gen_b.decode(h_a)
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):
        # def _compute_kl(self, mu, sd):
        #     mu_2 = torch.pow(mu, 2)
        #     sd_2 = torch.pow(sd, 2)
        #     encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        #     return encoding_loss
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def gen_update(self, x_a, x_b, float_data, hyperparameters):
        self.gen_opt.zero_grad()
        self.task_opt.zero_grad()

        # init data (used only for its extract_* helpers)
        full_dataset = hyperparameters['train_dataset_name']
        real_dataset = hyperparameters['target_domain_path']
        dataset = CoILDataset(full_dataset,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5),
                                                       (0.5, 0.5, 0.5))
                              ]))

        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(h_a + n_a)
        x_b_recon = self.gen_b.decode(h_b + n_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # encode again
        h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
        h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
        # decode again (if needed; the cycle losses below assume recon_x_cyc_w > 0)
        x_aba = self.gen_a.decode(
            h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # task part
        identity_embed = h_a
        cycle_embed = h_a_recon
        identity_task = self.netF(
            identity_embed, Variable(dataset.extract_inputs(float_data)).cuda())
        cycle_task = self.netF(
            cycle_embed, Variable(dataset.extract_inputs(float_data)).cuda())
        controls = Variable(float_data[:, dataset.controls_position(), :])

        # task loss
        self.lossF_identity_task = self.Task_Loss.MSELoss(
            identity_task,
            Variable(dataset.extract_targets(float_data)).cuda(),
            controls.cuda(),
            Variable(dataset.extract_inputs(float_data)).cuda())
        self.lossF_cycle_task = self.Task_Loss.MSELoss(
            cycle_task,
            Variable(dataset.extract_targets(float_data)).cuda(),
            controls.cuda(),
            Variable(dataset.extract_inputs(float_data)).cuda())
        self.lossF_task = self.lossF_identity_task + self.lossF_cycle_task

        # reconstruction loss
        # print(x_a_recon[0][0][:5][:5])
        # print("Help loss:", self.recon_criterion(x_a_recon, x_a))
        # print("identity task", identity_task[0])
        # print("cycle task", cycle_task[0])
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
        self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['task_w'] * self.lossF_task
        self.loss_gen_total.backward()
        self.gen_opt.step()

        # Second, task-only pass on the same embeddings.
        self.task_opt.zero_grad()
        identity_task = self.netF(
            identity_embed, Variable(dataset.extract_inputs(float_data)).cuda())
        cycle_task = self.netF(
            cycle_embed, Variable(dataset.extract_inputs(float_data)).cuda())
        controls = Variable(float_data[:, dataset.controls_position(), :])

        # task loss
        self.lossF_identity_task = self.Task_Loss.MSELoss(
            identity_task,
            Variable(dataset.extract_targets(float_data)).cuda(),
            controls.cuda(),
            Variable(dataset.extract_inputs(float_data)).cuda())
        self.lossF_cycle_task = self.Task_Loss.MSELoss(
            cycle_task,
            Variable(dataset.extract_targets(float_data)).cuda(),
            controls.cuda(),
            Variable(dataset.extract_inputs(float_data)).cuda())
        self.lossF_task = self.lossF_identity_task + self.lossF_cycle_task
        self.lossF_task.backward()  # fixed: without this backward, task_opt.step() applied empty gradients
        self.task_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        self.eval()
        x_a.volatile = True
        x_b.volatile = True
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        for i in range(x_a.size(0)):
            h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
            h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(h_a))
            x_b_recon.append(self.gen_b.decode(h_b))
            x_ba.append(self.gen_a.decode(h_b))
            x_ab.append(self.gen_b.decode(h_a))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # encode
        h_a, n_a = self.gen_a.encode(x_a)
        h_b, n_b = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(h_b + n_b)
        x_ab = self.gen_b.decode(h_a + n_a)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        # Save generators, discriminators, and optimizers
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
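# No driver loop for UNIT_Trainer appears in this file. A minimal usage sketch
# built from the update methods above; `loader_a` (yielding (images,
# float_data) pairs), `loader_b`, `hyperparameters`, and `snapshot_dir` are
# assumed names, and the snapshot interval key follows the usual UNIT config
# layout rather than anything confirmed by this code:
trainer = UNIT_Trainer(hyperparameters)
for it, ((x_a, float_data), x_b) in enumerate(zip(loader_a, loader_b)):
    x_a, x_b = x_a.cuda(), x_b.cuda()
    trainer.dis_update(x_a, x_b, hyperparameters)              # discriminators first
    trainer.gen_update(x_a, x_b, float_data, hyperparameters)  # then generators + task head
    trainer.update_learning_rate()
    if (it + 1) % hyperparameters['snapshot_save_iter'] == 0:  # assumed config key
        trainer.save(snapshot_dir, it)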
def execute(gpu, exp_batch, exp_alias):
    manualSeed = g_conf.SEED
    torch.cuda.manual_seed(manualSeed)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    merge_with_yaml(os.path.join('configs', exp_batch, exp_alias + '.yaml'))
    set_type_of_process('train')

    coil_logger.add_message('Loading', {'GPU': gpu})
    if not os.path.exists('_output_logs'):
        os.mkdir('_output_logs')
    sys.stdout = open(os.path.join(
        '_output_logs',
        g_conf.PROCESS_NAME + '_' + str(os.getpid()) + ".out"),
        "a",
        buffering=1)
    if monitorer.get_status(exp_batch, exp_alias + '.yaml',
                            g_conf.PROCESS_NAME)[0] == "Finished":
        return

    full_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                g_conf.TRAIN_DATASET_NAME)
    real_dataset = os.path.join(os.environ["COIL_DATASET_PATH"],
                                "FinalRealWorldDataset")

    # main data loader
    dataset = CoILDataset(full_dataset,
                          real_dataset,
                          transform=transforms.Compose([transforms.ToTensor()]))
    sampler = BatchSequenceSampler(
        splitter.control_steer_split(dataset.measurements, dataset.meta_data),
        g_conf.BATCH_SIZE, g_conf.NUMBER_IMAGES_SEQUENCE,
        g_conf.SEQUENCE_STRIDE)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=sampler,
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=True)
    # real image dataloader

    l1weight = g_conf.L1_WEIGHT
    image_size = (88, 200)

    print("Configurations of ", exp_alias)
    print("GANMODEL_NAME", g_conf.GANMODEL_NAME)
    print("LOSS_FUNCTION", g_conf.LOSS_FUNCTION)
    print("TYPE", g_conf.TYPE)
    print("L1 WEIGHT", g_conf.L1_WEIGHT)

    # netD, netG and netF are assumed to have been constructed from
    # g_conf.GANMODEL_NAME before this point; their construction is not part
    # of this excerpt
    optimD = torch.optim.Adam(netD.parameters(),
                              lr=g_conf.LR_D,
                              betas=(0.7, 0.999))
    optimG = torch.optim.Adam(netG.parameters(),
                              lr=g_conf.LR_G,
                              betas=(0.7, 0.999))
    if g_conf.TYPE == 'task':
        optimF = torch.optim.Adam(netF.parameters(), lr=g_conf.LEARNING_RATE)
        Task_Loss = TaskLoss()

    if g_conf.LOSS_FUNCTION == 'LSGAN':
        Loss = torch.nn.MSELoss().cuda()
    elif g_conf.LOSS_FUNCTION == 'NORMAL':
        Loss = torch.nn.BCELoss().cuda()
    L1_loss = torch.nn.L1Loss().cuda()

    iteration = 0
    best_loss_iter_F = 0
    best_loss_iter_G = 0
    best_lossF = 1000000.0
    best_lossD = 1000000.0
    best_lossG = 1000000.0
    accumulated_time = 0
    gen_iterations = 0

    netG.train()
    netD.train()
    netF.train()
    capture_time = time.time()

    if not os.path.exists('./imgs_' + exp_alias):
        os.mkdir('./imgs_' + exp_alias)

    #TODO put family for losses

    fake_img_pool = ImagePool(50)

    for data in data_loader:

        set_requires_grad(netD, True)
        set_requires_grad(netF, True)
        set_requires_grad(netG, True)

        input_data, float_data, tgt_imgs = data

        if g_conf.IF_AUG:
            inputs = augmenter(0, input_data['rgb'])
        else:
            inputs = input_data['rgb'].cuda()
        tgt_imgs = tgt_imgs.cuda()

        # inputs_in and val were undefined in this excerpt; an identity
        # normalisation (inputs_in = inputs, val = 0.0) is assumed here so
        # the code runs as written
        inputs_in = inputs
        val = 0.0

        #TODO: make sure the F network does not get optimized by G optim
        controls = float_data[:, dataset.controls_position(), :]

        embed, branches = netF(inputs_in,
                               dataset.extract_inputs(float_data).cuda())
        print("Branch Outputs:::", branches[0][0])

        embed_inputs = embed
        fake_inputs = netG(embed_inputs.detach())
        fake_inputs_in = fake_inputs

        if iteration % 500 == 0:
            imgs_to_save = torch.cat(
                (inputs_in[:2] + val, fake_inputs_in[:2]), 0).cpu().data
            vutils.save_image(imgs_to_save,
                              './imgs_' + exp_alias + '/' + str(iteration) +
                              '_real_and_fake.png',
                              normalize=True)
            coil_logger.add_image("Images", imgs_to_save, iteration)

        ##--------------------Discriminator part!!!!!!!!!!-------------------##
        set_requires_grad(netD, True)
        set_requires_grad(netF, False)
        set_requires_grad(netG, False)
        optimD.zero_grad()

        ##fake
        fake_inputs_forD = fake_img_pool.query(fake_inputs.detach())
        outputsD_fake_forD = netD(fake_inputs_forD.detach())
        labsize = outputsD_fake_forD.size()
        labels_fake = torch.zeros(labsize)  #Fake labels
        labels_fake_noise = torch.rand(
            labels_fake.size()) * 0.05 - 0.025  #Label smoothing

        if g_conf.LABSMOOTH == 1:
            labels_fake = labels_fake + labels_fake_noise

        labels_fake = Variable(labels_fake).cuda()
        lossD_fake = Loss(outputsD_fake_forD, labels_fake)

        ##real
        outputsD_real = netD(inputs)
        labsize = outputsD_real.size()
        labels_real = torch.ones(labsize)  #Real labels
        labels_real_noise = torch.rand(
            labels_real.size()) * 0.05 - 0.025  #Label smoothing

        if g_conf.LABSMOOTH == 1:
            labels_real = labels_real + labels_real_noise

        labels_real = Variable(labels_real).cuda()
        lossD_real = Loss(outputsD_real, labels_real)

        #Discriminator updates
        lossD = (lossD_real + lossD_fake) * 0.5
        lossD /= len(inputs)
        lossD.backward()
        optimD.step()

        coil_logger.add_scalar('Total LossD', lossD.data, iteration)
        coil_logger.add_scalar('Real LossD', lossD_real.data / len(inputs),
                               iteration)
        coil_logger.add_scalar('Fake LossD', lossD_fake.data / len(inputs),
                               iteration)

        ##--------------------Generator part!!!!!!!!!!-----------------------##
        set_requires_grad(netD, False)
        set_requires_grad(netF, False)
        set_requires_grad(netG, True)
        optimG.zero_grad()

        outputsD_fake_forG = netD(fake_inputs)

        #Generator updates
        lossG_adv = Loss(outputsD_fake_forG, labels_real)
        lossG_smooth = L1_loss(fake_inputs, inputs)
        lossG = (lossG_adv + l1weight * lossG_smooth) / (1.0 + l1weight)
        lossG /= len(inputs)
        print("LossG", lossG.data.tolist())

        lossG.backward()
        optimG.step()

        #####Task network updates##########################
        set_requires_grad(netD, False)
        set_requires_grad(netF, True)
        set_requires_grad(netG, False)
        optimF.zero_grad()

        lossF = Task_Loss.MSELoss(branches,
                                  dataset.extract_targets(float_data).cuda(),
                                  controls.cuda(),
                                  dataset.extract_inputs(float_data).cuda())
        coil_logger.add_scalar('Task Loss', lossF.data, iteration)
        lossF.backward()
        optimF.step()

        coil_logger.add_scalar('Total LossG', lossG.data, iteration)
        coil_logger.add_scalar('Adv LossG', lossG_adv.data / len(inputs),
                               iteration)
        coil_logger.add_scalar('Smooth LossG', lossG_smooth.data / len(inputs),
                               iteration)

        #optimization for one iter done!
        position = random.randint(0, len(float_data) - 1)

        if lossD.data < best_lossD:
            best_lossD = lossD.data.tolist()

        if lossG.data < best_lossG:
            best_lossG = lossG.data.tolist()
            best_loss_iter_G = iteration

        if lossF.data < best_lossF:
            best_lossF = lossF.data.tolist()
            best_loss_iter_F = iteration

        accumulated_time += time.time() - capture_time
        capture_time = time.time()

        print("LossD", lossD.data.tolist(), "LossG", lossG.data.tolist(),
              "BestLossD", best_lossD, "BestLossG", best_lossG, "LossF",
              lossF.data.tolist(), "BestLossF", best_lossF, "Iteration",
              iteration, "Best Loss Iteration G", best_loss_iter_G,
              "Best Loss Iteration F", best_loss_iter_F)

        coil_logger.add_message(
            'Iterating', {
                'Iteration': iteration,
                'LossD': lossD.data.tolist(),
                'LossG': lossG.data.tolist(),
                'Images/s': (iteration * g_conf.BATCH_SIZE) / accumulated_time,
                'BestLossD': best_lossD,
                'BestLossG': best_lossG,
                'BestLossIterationG': best_loss_iter_G,
                'BestLossF': best_lossF,
                'BestLossIterationF': best_loss_iter_F,
                'GroundTruth': dataset.extract_targets(
                    float_data)[position].data.tolist(),
                'Inputs': dataset.extract_inputs(
                    float_data)[position].data.tolist()
            }, iteration)

        if is_ready_to_save(iteration):
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter_G': best_loss_iter_G,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'checkpoints',
                             str(iteration) + '.pth'))

        if iteration == best_loss_iter_G and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'total_time': accumulated_time,
                'best_loss_iter_G': best_loss_iter_G
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelG.pth'))

        if iteration == best_loss_iter_F and iteration > 10000:
            state = {
                'iteration': iteration,
                'stateD_dict': netD.state_dict(),
                'stateG_dict': netG.state_dict(),
                'stateF_dict': netF.state_dict(),
                'best_lossD': best_lossD,
                'best_lossG': best_lossG,
                'best_lossF': best_lossF,
                'total_time': accumulated_time,
                'best_loss_iter_F': best_loss_iter_F
            }
            torch.save(
                state,
                os.path.join('/datatmp/Experiments/rohitgan/_logs', exp_batch,
                             exp_alias, 'best_modelF.pth'))

        iteration += 1
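# Hedged launch sketch (added for illustration; not in the original source):
# one minimal way execute() above could be invoked as a standalone process.
# The flag names and defaults below are assumptions, not taken from this
# file.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0',
                        help='GPU id exported to CUDA_VISIBLE_DEVICES')
    parser.add_argument('--exp-batch', default='sample_batch',
                        help='experiment batch folder under configs/')
    parser.add_argument('--exp-alias', default='sample_experiment',
                        help='yaml name (without extension) of the experiment')
    args = parser.parse_args()
    execute(args.gpu, args.exp_batch, args.exp_alias)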