class CycleGANModel():
    """Image-to-state translator trained with GAN and forward-dynamics consistency losses.

    ``netG_B`` maps B-domain inputs to A-domain states, ``netF_A`` is a forward
    model ``(state, action) -> next state`` and ``netD_B`` discriminates real
    from generated A-domain states.  Only ``netF_A`` and ``netG_B`` are in
    ``optimizer_G`` and only ``netD_B`` is stepped by ``optimize_parameters``;
    the remaining networks are constructed, saved and loaded but never
    optimized by code in this class.
    """

    def __init__(self, opt):
        """Build the environment, networks, losses and optimizers from ``opt``.

        ``opt.state_dim`` / ``opt.action_dim`` equal to 0 mean "take the value
        from the environment"; the resolved values are written back to ``opt``
        before the networks are built.  Every network is moved to CUDA.
        """
        self.opt = opt
        self.isTrain = opt.istrain
        self.env = dmc2gym.make(
            domain_name=opt.domain_name,
            task_name=opt.task_name,
            seed=0,
            visualize_reward=False,
            from_pixels=True,
            height=256,
            width=256,
            frame_skip=opt.frame_skip,
        )
        self.env.seed(0)

        if opt.state_dim == 0:
            self.state_dim = self.env.observation_space.shape[0]
        else:
            self.state_dim = opt.state_dim
        if opt.action_dim == 0:
            self.action_dim = self.env.action_space.shape[0]
        else:
            self.action_dim = opt.action_dim
        # The networks below read the dimensions from opt, so publish them there.
        opt.state_dim = self.state_dim
        opt.action_dim = self.action_dim
        self.max_action = float(self.env.action_space.high[0])
        self.img_policy = ImgPolicy(opt)

        self.Tensor = torch.cuda.FloatTensor

        self.netG_A = img2state(opt=self.opt).cuda()
        self.netG_B = img2state(opt=self.opt).cuda()
        self.net_action_G_A = AGmodel(flag='A2B', opt=self.opt).cuda()
        self.net_action_G_B = AGmodel(flag='B2A', opt=self.opt).cuda()
        self.netF_A = Fmodel(self.opt).cuda()
        self.reset_buffer()

        # Discriminators are built regardless of isTrain.
        self.netD_A = imgDmodel(opt=self.opt).cuda()
        self.netD_B = stateDmodel(opt=self.opt).cuda()
        self.net_action_D_A = ADmodel(opt=self.opt).cuda()
        self.net_action_D_B = ADmodel(opt=self.opt).cuda()

        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        self.fake_action_A_pool = ImagePool(pool_size=128)
        self.fake_action_B_pool = ImagePool(pool_size=128)

        # Loss functions; any opt.loss other than 'l1'/'l2' selects SmoothL1.
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = nn.MSELoss()
        else:
            self.criterionCycle = nn.SmoothL1Loss()
        self.ImgcriterionCycle = nn.MSELoss()
        self.StatecriterionCycle = nn.L1Loss()

        # Only the forward model and G_B are trained by the generator optimizer.
        parameters = [
            {'params': self.netF_A.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netG_B.parameters(), 'lr': self.opt.G_lr},
        ]
        self.optimizer_G = torch.optim.Adam(parameters)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
        self.optimizer_action_D_A = torch.optim.Adam(self.net_action_D_A.parameters())
        self.optimizer_action_D_B = torch.optim.Adam(self.net_action_D_B.parameters())

        self.use_mask = opt.use_mask
        self.mask = torch.tensor(np.array(opt.mask)).float()

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')

    def parallel_init(self, device_ids=None):
        """Wrap G_B, F_A, D_A and D_B in ``torch.nn.DataParallel``.

        Args:
            device_ids: devices to replicate on; ``None`` means ``[0]``.
        """
        # None instead of a list literal: a mutable default would be shared
        # between calls.
        if device_ids is None:
            device_ids = [0]
        self.netG_B = torch.nn.DataParallel(self.netG_B, device_ids=device_ids)
        self.netF_A = torch.nn.DataParallel(self.netF_A, device_ids=device_ids)
        self.netD_A = torch.nn.DataParallel(self.netD_A, device_ids=device_ids)
        self.netD_B = torch.nn.DataParallel(self.netD_B, device_ids=device_ids)

    def train_forward_state(self, dataF, pretrained=False):
        """Train (or load) the forward model ``netF_A`` and report a held-out loss.

        Args:
            dataF: iterable with ``len()`` yielding ``(state, action, result)``
                items (presumably a DataLoader of batches -- confirm at caller).
                Items with index up to 80% of ``len(dataF)`` train the model,
                items from that index on are used for evaluation.
            pretrained: when true, load the weights from disk and return.

        Returns:
            None.
        """
        weight_name = 'pred_mask.pth' if self.use_mask else 'pred.pth'
        weight_path = os.path.join(
            self.opt.log_root,
            '{}_{}_data'.format(self.opt.domain_name, self.opt.task_name),
            '{}_{}/{}'.format(self.opt.data_type1, self.opt.data_id1, weight_name))

        if pretrained:
            self.netF_A.load_state_dict(torch.load(weight_path))
            print('forward model has loaded!')
            return None

        lr = 1e-3
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=lr)
        loss_fn = nn.L1Loss()
        data_size = len(dataF)
        split = data_size * 0.8
        # The mask does not change during training; move it to the GPU once.
        mask = self.mask.cuda() if self.use_mask else None

        for epoch in range(self.opt.f_epoch):
            epoch_loss, cmp_loss = 0, 0
            if epoch in [3, 7, 10, 15]:
                # Halve the learning rate; rebuilding Adam also resets its moments.
                lr *= 0.5
                optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=lr)
            for i, item in enumerate(tqdm(dataF)):
                if i > split:
                    continue
                state, action, result = item
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                if self.use_mask:
                    loss = ((out - result) * mask).abs().mean()
                else:
                    loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                # Baseline: error of predicting "next state == current state".
                cmp_loss += loss_fn(state, result).item()
            print('epoch:{} loss:{:.7f} cmp_loss:{:.7f}'
                  .format(epoch, epoch_loss / (0.8 * data_size), cmp_loss / (0.8 * data_size)))
            torch.save(self.netF_A.state_dict(), weight_path)
        print('forward model has been trained!')

        print('forward model starts to evaluate!')
        epoch_loss, cmp_loss = 0, 0
        # No parameter update happens here, so skip building autograd graphs.
        with torch.no_grad():
            for i, item in enumerate(tqdm(dataF)):
                if i < split:
                    continue
                state, action, result = item
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                epoch_loss += loss_fn(out, result).item()
                cmp_loss += loss_fn(state, result).item()
        print('loss:{:.7f} cmp_loss:{:.7f}'.
              format(epoch_loss / (0.2 * data_size), cmp_loss / (0.2 * data_size)))

    def set_input(self, input):
        """Unpack one batch.

        ``input[1][0]`` is the A-domain (state) sample, ``input[0]`` holds
        ``(B at t0, action, B at t1)`` and ``input[2]`` the ground-truth
        states for t0 and t1, which are moved to CUDA immediately.
        """
        # A is state
        self.input_A = input[1][0]
        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()

    def forward(self):
        """Move the tensors stored by ``set_input`` to CUDA as float tensors."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Compute generator and D_B losses without optimizer steps.

        Gradients are only back-propagated when ``self.isTrain`` is true.
        """
        self.forward()
        self.backward_G()
        self.backward_D_B()

    def backward_D_basic(self, netD, real, fake):
        """Return the discriminator loss ``0.5 * (loss(real) + loss(fake))``.

        ``fake`` is detached so no gradient reaches the generator.  The loss
        is back-propagated when ``self.isTrain`` is true.
        """
        loss_D_real = self.criterionGAN(netD(real), True)
        loss_D_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        if self.isTrain:
            loss_D.backward()
        return loss_D

    def backward_D_B(self):
        """Discriminator loss for ``netD_B`` using pooled fake t0 states."""
        fake_A = self.fake_A_pool.query(self.fake_At0.detach())
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute the generator objective and record diagnostics.

        The objective is the sum of three GAN terms (on ``G_B(B_t0)``, on
        ``F_A(G_B(B_t0), action)`` and on ``G_B(B_t1)``) plus a consistency
        term between ``F_A(G_B(B_t0), action)`` and ``G_B(B_t1)``, optionally
        weighted element-wise by ``self.mask``.  L1 errors against the
        ground-truth states are recorded for logging only.
        """
        lambda_G_B0 = self.opt.lambda_G0
        lambda_G_B1 = self.opt.lambda_G1
        lambda_G_B2 = self.opt.lambda_G2
        lambda_F = self.opt.lambda_F

        # GAN loss on the translated t0 sample.
        fake_At0 = self.netG_B(self.real_Bt0)
        loss_G_Bt0 = self.criterionGAN(self.netD_B(fake_At0), True) * lambda_G_B0

        # GAN loss on the forward-model prediction of the t1 state.
        fake_At1 = self.netF_A(fake_At0, self.action)
        loss_G_Bt1 = self.criterionGAN(self.netD_B(fake_At1), True) * lambda_G_B1

        # Consistency: predicted t1 state vs. translated t1 sample.
        pred_At1 = self.netG_B(self.real_Bt1)
        cycle_label = torch.zeros_like(fake_At1).float().cuda()
        diff = fake_At1 - pred_At1
        if self.use_mask:
            diff = diff * self.mask.cuda(device=fake_At1.device)
        loss_cycle = self.criterionCycle(diff, cycle_label) * lambda_F

        # GAN loss on the translated t1 sample.
        loss_G_Bt2 = self.criterionGAN(self.netD_B(pred_At1), True) * lambda_G_B2

        # Supervised diagnostics; not part of the optimized objective.
        state_err_t0 = nn.L1Loss()(fake_At0, self.gt0)
        state_err_t1 = nn.L1Loss()(pred_At1, self.gt1)

        loss_G = loss_G_Bt0 + loss_G_Bt1 + loss_G_Bt2 + loss_cycle
        if self.isTrain:
            loss_G.backward()

        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data

        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle = loss_cycle.item()
        self.loss_state_lt0 = state_err_t0.item()
        self.loss_state_lt1 = state_err_t1.item()

        self.gt_buffer0.append(self.gt0.cpu().data.numpy())
        self.pred_buffer0.append(self.fake_At0.cpu().data.numpy())
        self.gt_buffer1.append(self.gt1.cpu().data.numpy())
        self.pred_buffer1.append(self.fake_At1.cpu().data.numpy())

    def optimize_parameters(self):
        """Run one generator step, one D_B step, then log the losses."""
        self.forward()
        # Generator / forward model.
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

        self.push_current_errors()

    def push_current_errors(self):
        """Append a snapshot of the current scalar losses to ``self.error``."""
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc', self.loss_cycle)])
        self.error.append(ret_errors)

    def get_current_errors(self):
        """Return averaged losses and clear the accumulated snapshots.

        The average is taken over the current loss values plus every snapshot
        in ``self.error`` (hence the ``len + 1`` denominator).
        NOTE(review): after ``optimize_parameters`` the current values are
        also the last snapshot, so the latest step is weighted twice --
        confirm whether that is intended.
        """
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc', self.loss_cycle)])
        for snapshot in self.error:
            for key, value in snapshot.items():
                ret_errors[key] += value
        count = len(self.error) + 1
        for key in ret_errors:
            ret_errors[key] /= count
        self.error = []
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        """Save ``network``'s state dict to ``<path>/model_<network_label>.pth``."""
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Save all eight generator/discriminator networks under ``path``."""
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)

        self.save_network(self.net_action_G_B, 'action_G_B', path)
        self.save_network(self.net_action_D_B, 'action_D_B', path)
        self.save_network(self.net_action_G_A, 'action_G_A', path)
        self.save_network(self.net_action_D_A, 'action_D_A', path)

    def load_network(self, network, network_label, path):
        """Load ``<path>/model_<network_label>.pth`` into ``network``."""
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        """Load all eight generator/discriminator networks from ``path``."""
        self.load_network(self.netG_B, 'G_B', path)
        self.load_network(self.netD_B, 'D_B', path)
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netD_A, 'D_A', path)

        self.load_network(self.net_action_G_B, 'action_G_B', path)
        self.load_network(self.net_action_D_B, 'action_D_B', path)
        self.load_network(self.net_action_G_A, 'action_G_A', path)
        self.load_network(self.net_action_D_A, 'action_D_A', path)

    def show_points(self, gt_data, pred_data):
        """Scatter ground truth against prediction, one panel per column.

        Prints the per-column mean absolute error and draws into a new
        square grid of subplots large enough for every column.
        """
        print(abs(gt_data - pred_data).mean(0))
        n_dims = gt_data.shape[1]
        ncols = int(np.sqrt(n_dims)) + 1
        nrows = int(np.sqrt(n_dims)) + 1
        assert (ncols * nrows >= n_dims)
        _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))
        axes = axes.flatten()
        for ax_i, ax in enumerate(axes):
            if ax_i >= n_dims:
                continue
            ax.scatter(gt_data[:, ax_i], pred_data[:, ax_i], s=3, label='xyz_{}'.format(ax_i))

    def npdata(self, item):
        """Return ``item`` as a NumPy array on the CPU."""
        return item.cpu().data.numpy()

    def reset_buffer(self):
        """Clear the visualization buffers and the accumulated loss snapshots."""
        self.gt_buffer0 = []
        self.pred_buffer0 = []
        self.gt_buffer1 = []
        self.pred_buffer1 = []
        self.error = []

    def visual(self, path):
        """Save scatter plots for the buffered t0 and t1 samples, then reset.

        The t0 plot is written to ``path`` and the t1 plot to ``path`` with
        ``'.jpg'`` replaced by ``'_step1.jpg'``.
        """
        gt_data = np.vstack(self.gt_buffer0)
        pred_data = np.vstack(self.pred_buffer0)
        self.show_points(gt_data, pred_data)
        plt.savefig(path)
        # Close instead of only clearing: every show_points call opens a new
        # figure, which would otherwise stay alive for the whole run.
        plt.close()

        gt_data = np.vstack(self.gt_buffer1)
        pred_data = np.vstack(self.pred_buffer1)
        self.show_points(gt_data, pred_data)
        plt.savefig(path.replace('.jpg', '_step1.jpg'))
        plt.close()
        self.reset_buffer()
class SSCycleGANModel():
    """State-to-state translation model with GAN and forward-dynamics losses.

    Both generators are ``state2state`` networks and both discriminators are
    ``stateDmodel`` networks.  The forward model ``netF_A`` is trained (or
    loaded) during construction.  ``optimize_parameters`` steps
    ``optimizer_G`` and ``optimizer_D_B`` only.
    """

    def __init__(self, opt):
        """Build networks, train/load the forward model, set up losses and optimizers."""
        self.opt = opt
        self.isTrain = opt.istrain

        self.Tensor = torch.cuda.FloatTensor

        self.netG_A = state2state(opt=self.opt).cuda()
        self.netG_B = state2state(opt=self.opt).cuda()
        self.net_action_G_A = AGmodel(flag='A2B', opt=self.opt).cuda()
        self.net_action_G_B = AGmodel(flag='B2A', opt=self.opt).cuda()
        self.netF_A = Fmodel(self.opt).cuda()
        self.netF_B = ImgFmodel(opt=self.opt).cuda()
        self.dataF = Robotdata.get_loader(opt)
        self.train_forward_state(pretrained=opt.pretrain_f)
        self.reset_buffer()

        # Discriminators are built regardless of isTrain.
        self.netD_A = stateDmodel(opt=self.opt).cuda()
        self.netD_B = stateDmodel(opt=self.opt).cuda()
        self.net_action_D_A = ADmodel(opt=self.opt).cuda()
        self.net_action_D_B = ADmodel(opt=self.opt).cuda()

        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        self.fake_action_A_pool = ImagePool(pool_size=128)
        self.fake_action_B_pool = ImagePool(pool_size=128)

        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = nn.MSELoss()
        else:
            # Without a fallback criterionCycle stayed undefined and the first
            # backward_G call failed with AttributeError; SmoothL1 matches the
            # fallback used by CycleGANModel.
            self.criterionCycle = nn.SmoothL1Loss()
        self.ImgcriterionCycle = nn.MSELoss()
        self.StatecriterionCycle = nn.L1Loss()

        # initialize optimizers
        parameters = [
            {'params': self.netF_A.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netF_B.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netG_A.parameters(), 'lr': self.opt.G_lr},
            {'params': self.netG_B.parameters(), 'lr': self.opt.G_lr},
            {'params': self.net_action_G_A.parameters(), 'lr': self.opt.A_lr},
            {'params': self.net_action_G_B.parameters(), 'lr': self.opt.A_lr},
        ]
        self.optimizer_G = torch.optim.Adam(parameters)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
        self.optimizer_action_D_A = torch.optim.Adam(self.net_action_D_A.parameters())
        self.optimizer_action_D_B = torch.optim.Adam(self.net_action_D_B.parameters())

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')

    def train_forward_state(self, pretrained=False):
        """Train ``netF_A`` for 50 epochs on ``self.dataF`` or load it from disk.

        Each item of ``self.dataF`` is indexed with ``[1]`` to obtain a
        ``(state, action, result)`` triple.  Weights live at
        ``<opt.data_root>/data_<opt.test_id1>/pred.pth`` and are rewritten
        after every epoch.

        Returns:
            None.
        """
        weight_path = os.path.join(self.opt.data_root, 'data_{}/pred.pth'.format(self.opt.test_id1))
        if pretrained:
            self.netF_A.load_state_dict(torch.load(weight_path))
            print('forward model has loaded!')
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = nn.L1Loss()
        for epoch in range(50):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch, epoch_loss / len(self.dataF)))
            torch.save(self.netF_A.state_dict(), weight_path)
        print('forward model has been trained!')

    def train_forward_img(self, pretrained=False):
        """Train ``netF_B`` for 50 epochs with MSE or load it from ``./model/imgpred.pth``.

        The network output is multiplied by 100 before the loss is computed.

        Returns:
            None.
        """
        weight_path = './model/imgpred.pth'
        if pretrained:
            self.netF_B.load_state_dict(torch.load(weight_path))
            return None
        optimizer = torch.optim.Adam(self.netF_B.parameters(), lr=1e-3)
        loss_fn = nn.MSELoss()
        for epoch in range(50):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_B(state, action) * 100
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch, epoch_loss / len(self.dataF)))
            torch.save(self.netF_B.state_dict(), weight_path)
        print('forward model has been trained!')

    def set_input(self, input):
        """Unpack one batch.

        The A-domain sample is ``input[1][0]`` and the action ``input[0][1]``.
        The B-domain t0/t1 samples are read from ``input[2]``, the same
        entries that are also stored (on CUDA) as ground truth.
        """
        self.input_A = input[1][0]
        self.input_Bt0 = input[2][0]
        self.input_Bt1 = input[2][1]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()

    def forward(self):
        """Move the tensors stored by ``set_input`` to CUDA as float tensors."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Compute generator and D_B losses without optimizer steps."""
        self.forward()
        self.backward_G()
        self.backward_D_B()

    def backward_D_basic(self, netD, real, fake):
        """Return ``0.5 * (loss(real) + loss(fake))``; ``fake`` is detached.

        The loss is back-propagated when ``self.isTrain`` is true.
        """
        loss_D_real = self.criterionGAN(netD(real), True)
        loss_D_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        if self.isTrain:
            loss_D.backward()
        return loss_D

    def backward_D_B(self):
        """Discriminator loss for ``netD_B`` using pooled fake t0 states."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute the generator objective and record diagnostics.

        The optimized objective is ``loss_G_Bt0 + loss_G_Bt1 + loss_cycle``:
        GAN terms on ``G_B(B_t0)`` and ``F_A(G_B(B_t0), action)`` plus the
        consistency between that prediction and ``G_B(B_t1)``.
        """
        lambda_G_B0 = self.opt.lambda_G0
        lambda_G_B1 = self.opt.lambda_G1
        lambda_F = self.opt.lambda_F
        lambda_C = 100.

        # GAN loss on the translated t0 sample.
        fake_At0 = self.netG_B(self.real_Bt0)
        loss_G_Bt0 = self.criterionGAN(self.netD_B(fake_At0), True) * lambda_G_B0

        # NOTE(review): the classic CycleGAN terms below (B->A->B, A->B->A and
        # the D_A GAN term) are evaluated but are NOT part of loss_G -- the
        # original code overwrote the sum that included them.  The forward
        # passes are kept so that any state the networks update while running
        # is unchanged; confirm whether the terms should be optimized.
        rec_At0 = self.netG_A(fake_At0)
        loss_cycle_original_A = self.criterionCycle(rec_At0, self.real_Bt0) * lambda_C
        fake_Bt0 = self.netG_A(self.real_A)
        loss_G_At0 = self.criterionGAN(self.netD_A(fake_Bt0), True) * lambda_G_B0
        rec_Bt0 = self.netG_B(fake_Bt0)
        loss_cycle_original_B = self.criterionCycle(rec_Bt0, self.real_A) * lambda_C
        loss_cycle_original = loss_cycle_original_A + loss_cycle_original_B + loss_G_At0  # unused, see note

        # GAN loss on the forward-model prediction of the t1 state.
        fake_At1 = self.netF_A(fake_At0, self.action)
        loss_G_Bt1 = self.criterionGAN(self.netD_B(fake_At1), True) * lambda_G_B1

        # Consistency: predicted t1 state vs. translated t1 sample.
        pred_At1 = self.netG_B(self.real_Bt1)
        cycle_label = torch.zeros_like(fake_At1).float().cuda()
        loss_cycle = self.criterionCycle(fake_At1 - pred_At1, cycle_label) * lambda_F

        # Supervised diagnostics; not part of the optimized objective.
        state_err_t0 = self.criterionCycle(fake_At0, self.gt0)
        state_err_t1 = self.criterionCycle(pred_At1, self.gt1)

        # combined loss
        loss_G = loss_G_Bt0 + loss_G_Bt1 + loss_cycle
        if self.isTrain:
            loss_G.backward()

        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data

        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle = loss_cycle.item()
        self.loss_state_lt0 = state_err_t0.item()
        self.loss_state_lt1 = state_err_t1.item()

        self.gt_buffer0.append(self.gt0.cpu().data.numpy())
        self.pred_buffer0.append(self.fake_At0.cpu().data.numpy())
        self.gt_buffer1.append(self.gt1.cpu().data.numpy())
        self.pred_buffer1.append(self.fake_At1.cpu().data.numpy())

    def optimize_parameters(self):
        """Run one generator step followed by one D_B step."""
        self.forward()
        # Generators / forward models.
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the most recent scalar losses as an ``OrderedDict``."""
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc', self.loss_cycle)])
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        """Save ``network``'s state dict to ``<path>/model_<network_label>.pth``."""
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Save all eight generator/discriminator networks under ``path``."""
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)

        self.save_network(self.net_action_G_B, 'action_G_B', path)
        self.save_network(self.net_action_D_B, 'action_D_B', path)
        self.save_network(self.net_action_G_A, 'action_G_A', path)
        self.save_network(self.net_action_D_A, 'action_D_A', path)

    def load_network(self, network, network_label, path):
        """Load ``<path>/model_<network_label>.pth`` into ``network``."""
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        """Load all eight generator/discriminator networks from ``path``."""
        self.load_network(self.netG_B, 'G_B', path)
        self.load_network(self.netD_B, 'D_B', path)
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netD_A, 'D_A', path)

        self.load_network(self.net_action_G_B, 'action_G_B', path)
        self.load_network(self.net_action_D_B, 'action_D_B', path)
        self.load_network(self.net_action_G_A, 'action_G_A', path)
        self.load_network(self.net_action_D_A, 'action_D_A', path)

    @staticmethod
    def _grid_shape(n_dims):
        """Return ``(ncols, nrows)`` of a subplot grid with at least ``n_dims`` cells.

        Starts from ``int(sqrt(n)) x (int(sqrt(n)) + 1)`` and grows the first
        factor by one when that is too small (e.g. n = 3, 7 or 8).
        """
        ncols = int(np.sqrt(n_dims))
        nrows = ncols + 1
        if ncols * nrows < n_dims:
            ncols += 1
        return ncols, nrows

    def show_points(self, gt_data, pred_data):
        """Scatter ground truth against prediction, one panel per column.

        Prints the per-column mean absolute error and draws into a new grid
        of subplots large enough for every column.
        """
        print(abs(gt_data - pred_data).mean(0))
        n_dims = gt_data.shape[1]
        ncols, nrows = self._grid_shape(n_dims)
        _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))
        axes = axes.flatten()
        for ax_i, ax in enumerate(axes):
            if ax_i >= n_dims:
                continue
            ax.scatter(gt_data[:, ax_i], pred_data[:, ax_i], s=3, label='xyz_{}'.format(ax_i))

    def npdata(self, item):
        """Return ``item`` as a NumPy array on the CPU."""
        return item.cpu().data.numpy()

    def reset_buffer(self):
        """Clear the ground-truth / prediction buffers used by ``visual``."""
        self.gt_buffer0 = []
        self.pred_buffer0 = []
        self.gt_buffer1 = []
        self.pred_buffer1 = []

    def visual(self, path):
        """Save scatter plots for the buffered t0 and t1 samples, then reset.

        The t0 plot is written to ``path`` and the t1 plot to ``path`` with
        ``'.jpg'`` replaced by ``'_step1.jpg'``.
        """
        gt_data = np.vstack(self.gt_buffer0)
        pred_data = np.vstack(self.pred_buffer0)
        self.show_points(gt_data, pred_data)
        plt.legend()
        plt.savefig(path)
        # Close instead of only clearing: every show_points call opens a new
        # figure, which would otherwise stay alive for the whole run.
        plt.close()

        gt_data = np.vstack(self.gt_buffer1)
        pred_data = np.vstack(self.pred_buffer1)
        self.show_points(gt_data, pred_data)
        plt.legend()
        plt.savefig(path.replace('.jpg', '_step1.jpg'))
        plt.close()
        self.reset_buffer()
class ActionCycleGANModel():
    """CycleGAN over actions: translates actions between domains A and B.

    ``net_action_G_A`` is applied to A-domain actions and ``net_action_G_B``
    to B-domain actions; ``net_action_D_A`` / ``net_action_D_B`` are the
    matching discriminators.  ``backward_G`` and ``optimize_parameters`` only
    involve these four action networks; the state/image networks are
    constructed (and ``netF_A`` trained or loaded) but not used in the
    objective.
    """

    def __init__(self, opt):
        """Build networks, train/load the forward model, set up losses and optimizers."""
        self.opt = opt
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        self.netG_A = state2img(opt=self.opt).cuda()
        self.netG_B = img2state(opt=self.opt).cuda()
        self.net_action_G_A = AGmodel(flag='A2B', opt=self.opt).cuda()
        self.net_action_G_B = AGmodel(flag='B2A', opt=self.opt).cuda()
        self.netF_A = Fmodel(self.opt).cuda()
        self.netF_B = ImgFmodel(opt=self.opt).cuda()
        self.dataF = Robotdata.get_loader(opt)
        self.train_forward_state(pretrained=opt.pretrain_f)
        self.reset_buffer()

        # Discriminators are built regardless of isTrain.
        self.netD_A = imgDmodel(opt=self.opt).cuda()
        self.netD_B = stateDmodel(opt=self.opt).cuda()
        self.net_action_D_A = ADmodel(opt=self.opt).cuda()
        self.net_action_D_B = ADmodel(opt=self.opt).cuda()

        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        self.fake_action_A_pool = ImagePool(pool_size=128)
        self.fake_action_B_pool = ImagePool(pool_size=128)

        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = nn.MSELoss()
        else:
            # Without a fallback criterionCycle stayed undefined and the first
            # backward_G call failed with AttributeError; SmoothL1 matches the
            # fallback used by CycleGANModel.
            self.criterionCycle = nn.SmoothL1Loss()
        self.ImgcriterionCycle = nn.MSELoss()
        self.StatecriterionCycle = nn.L1Loss()

        # initialize optimizers
        parameters = [
            {'params': self.netF_A.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netF_B.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netG_A.parameters(), 'lr': self.opt.G_lr},
            {'params': self.netG_B.parameters(), 'lr': self.opt.G_lr},
            {'params': self.net_action_G_A.parameters(), 'lr': self.opt.G_lr},
            {'params': self.net_action_G_B.parameters(), 'lr': self.opt.G_lr},
        ]
        self.optimizer_G = torch.optim.Adam(parameters)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
        self.optimizer_action_D_A = torch.optim.Adam(self.net_action_D_A.parameters())
        self.optimizer_action_D_B = torch.optim.Adam(self.net_action_D_B.parameters())

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')

    def train_forward_state(self, pretrained=False):
        """Train ``netF_A`` for 50 epochs with L1 or load it from ``./model/pred.pth``.

        Each item of ``self.dataF`` is indexed with ``[1]`` to obtain a
        ``(state, action, result)`` triple.  Weights are rewritten after
        every epoch.

        Returns:
            None.
        """
        weight_path = './model/pred.pth'
        if pretrained:
            self.netF_A.load_state_dict(torch.load(weight_path))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = nn.L1Loss()
        for epoch in range(50):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch, epoch_loss / len(self.dataF)))
            torch.save(self.netF_A.state_dict(), weight_path)
        print('forward model has been trained!')

    def train_forward_img(self, pretrained=False):
        """Train ``netF_B`` for 50 epochs with MSE or load it from ``./model/imgpred.pth``.

        The network output is multiplied by 100 before the loss is computed.

        Returns:
            None.
        """
        weight_path = './model/imgpred.pth'
        if pretrained:
            self.netF_B.load_state_dict(torch.load(weight_path))
            return None
        optimizer = torch.optim.Adam(self.netF_B.parameters(), lr=1e-3)
        loss_fn = nn.MSELoss()
        for epoch in range(50):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_B(state, action) * 100
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch, epoch_loss / len(self.dataF)))
            torch.save(self.netF_B.state_dict(), weight_path)
        print('forward model has been trained!')

    def set_input(self, input):
        """Unpack one batch.

        ``input[1]`` and ``input[0]`` each hold ``(t0 sample, action, t1
        sample)`` for domains A and B respectively; ``input[2]`` holds the
        ground-truth states for t0 and t1, which are moved to CUDA here.
        """
        # A is state
        self.input_At0 = input[1][0]
        self.input_At1 = input[1][2]
        self.input_action_A = input[1][1]
        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.input_action_B = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()

    def forward(self):
        """Move the tensors stored by ``set_input`` to CUDA as float tensors."""
        self.real_At0 = Variable(self.input_At0).float().cuda()
        self.real_At1 = Variable(self.input_At1).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action_A = Variable(self.input_action_A).float().cuda()
        self.action_B = Variable(self.input_action_B).float().cuda()

    def test(self):
        """Compute generator and D_B losses without optimizer steps."""
        self.forward()
        self.backward_G()
        self.backward_D_B()

    def backward_D_basic(self, netD, real, fake):
        """Return ``0.5 * (loss(real) + loss(fake))``; ``fake`` is detached.

        The loss is back-propagated when ``self.isTrain`` is true.
        """
        loss_D_real = self.criterionGAN(netD(real), True)
        loss_D_fake = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        if self.isTrain:
            loss_D.backward()
        return loss_D

    def backward_D_B(self):
        """Discriminator loss for ``netD_B`` on pooled ``fake_At0`` samples."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_At0, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_D_A(self):
        """Discriminator loss for ``netD_A`` on pooled ``fake_Bt0`` samples."""
        fake_B = self.fake_B_pool.query(self.fake_Bt0)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_action_D_B(self):
        """Discriminator loss for ``net_action_D_B`` (real: action_A)."""
        fake_action_A = self.fake_action_A_pool.query(self.fake_action_A)
        loss_action_D_B = self.backward_D_basic(self.net_action_D_B, self.action_A, fake_action_A)
        self.loss_action_D_B = loss_action_D_B.item()

    def backward_action_D_A(self):
        """Discriminator loss for ``net_action_D_A`` (real: action_B)."""
        fake_action_B = self.fake_action_B_pool.query(self.fake_action_B)
        loss_action_D_A = self.backward_D_basic(self.net_action_D_A, self.action_B, fake_action_B)
        self.loss_action_D_A = loss_action_D_A.item()

    def backward_G(self):
        """Compute the action-generator objective and record diagnostics.

        The objective is identity + GAN + cycle-consistency loss for both
        action generators.  The state outputs ``fake_At*`` / ``fake_Bt*`` are
        simply set to the ground truth, so the reported state errors are
        computed between identical tensors.
        """
        lambda_idt = 0.2
        lambda_G_action = 50.
        lambda_AC = self.opt.lambda_AC

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if a B-domain action is fed.
            idt_A = self.net_action_G_A(self.action_B)
            loss_idt_A = self.criterionCycle(idt_A, self.action_B) * lambda_AC * lambda_idt
            # G_B should be identity if an A-domain action is fed.
            idt_B = self.net_action_G_B(self.action_A)
            loss_idt_B = self.criterionCycle(idt_B, self.action_A) * lambda_AC * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_B(G_B(B)) for action
        fake_action_A = self.net_action_G_B(self.action_B)
        pred_fake = self.net_action_D_B(fake_action_A)
        loss_action_G_B = self.criterionGAN(pred_fake, True) * lambda_G_action

        # GAN loss D_A(G_A(A)) for action
        fake_action_B = self.net_action_G_A(self.action_A)
        pred_fake = self.net_action_D_A(fake_action_B)
        loss_action_G_A = self.criterionGAN(pred_fake, True) * lambda_G_action

        # Sum the identity TENSORS: the previous code added the Python floats
        # stored in self.loss_idt_A/B, so the identity loss shifted the
        # reported value but contributed no gradient.
        loss_gan_original = loss_action_G_B + loss_action_G_A + loss_idt_A + loss_idt_B

        # Cycle losses: B -> A -> B and A -> B -> A.
        rec_action_B = self.net_action_G_A(fake_action_A)
        loss_cycle_action_B = self.criterionCycle(rec_action_B, self.action_B) * lambda_AC
        rec_action_A = self.net_action_G_B(fake_action_B)
        loss_cycle_action_A = self.criterionCycle(rec_action_A, self.action_A) * lambda_AC
        loss_cycle_original = loss_cycle_action_B + loss_cycle_action_A

        # combined loss
        loss_G = loss_gan_original + loss_cycle_original
        if self.isTrain:
            loss_G.backward()

        # No state translation happens in this model; ground truth is stored
        # in place of generated states.
        self.fake_At0 = self.gt0.data
        self.fake_At1 = self.gt1.data
        self.fake_Bt0 = self.gt0.data
        self.fake_Bt1 = self.gt1.data
        self.fake_action_A = fake_action_A.data
        self.fake_action_B = fake_action_B.data

        self.loss_G_action_B = loss_action_G_B.item()
        self.loss_G_action_A = loss_action_G_A.item()
        self.loss_cycle_action_A = loss_cycle_action_A.item()
        self.loss_cycle_action_B = loss_cycle_action_B.item()

        self.loss_state_lt0 = self.criterionCycle(self.fake_At0, self.gt0).item()
        self.loss_state_lt1 = self.criterionCycle(self.fake_At1, self.gt1).item()

        self.gt_buffer.append(self.gt0.cpu().data.numpy())
        self.gt_buffer.append(self.gt1.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At0.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At1.cpu().data.numpy())

        self.realA_buffer.append(self.action_A.cpu().data.numpy())
        self.fakeA_buffer.append(self.fake_action_B.cpu().data.numpy())
        self.realB_buffer.append(self.action_B.cpu().data.numpy())
        self.fakeB_buffer.append(self.fake_action_A.cpu().data.numpy())

    def optimize_parameters(self):
        """Run one generator step, then one step for each action discriminator."""
        self.forward()
        # Generators.
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # action_D_B
        self.optimizer_action_D_B.zero_grad()
        self.backward_action_D_B()
        self.optimizer_action_D_B.step()
        # action_D_A
        self.optimizer_action_D_A.zero_grad()
        self.backward_action_D_A()
        self.optimizer_action_D_A.step()

    def get_current_errors(self):
        """Return the most recent scalar losses as an ``OrderedDict``."""
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_action_B', self.loss_action_D_B),
                                  ('D_action_A', self.loss_action_D_A),
                                  ('G_action_B', self.loss_G_action_B),
                                  ('G_action_A', self.loss_G_action_A),
                                  ('Cyc_action_B', self.loss_cycle_action_B),
                                  ('Cyc_action_A', self.loss_cycle_action_A)])
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        """Save ``network``'s state dict to ``<path>/model_<network_label>.pth``."""
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Save the four action networks under ``path``."""
        self.save_network(self.net_action_G_B, 'action_G_B', path)
        self.save_network(self.net_action_D_B, 'action_D_B', path)
        self.save_network(self.net_action_G_A, 'action_G_A', path)
        self.save_network(self.net_action_D_A, 'action_D_A', path)

    def load_network(self, network, network_label, path):
        """Load ``<path>/model_<network_label>.pth`` into ``network``."""
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        """Load all eight generator/discriminator networks from ``path``.

        NOTE(review): ``save`` only writes the four action networks, so
        ``path`` must also contain G_A/G_B/D_A/D_B checkpoints written by
        something else -- confirm this asymmetry is intended.
        """
        self.load_network(self.netG_B, 'G_B', path)
        self.load_network(self.netD_B, 'D_B', path)
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netD_A, 'D_A', path)

        self.load_network(self.net_action_G_B, 'action_G_B', path)
        self.load_network(self.net_action_D_B, 'action_D_B', path)
        self.load_network(self.net_action_G_A, 'action_G_A', path)
        self.load_network(self.net_action_D_A, 'action_D_A', path)

    def show_points(self):
        """Scatter real vs. translated actions from the buffers.

        Draws a 2 x 4 grid: the first four panels plot columns 0-3 of the
        A-domain action buffers, the last four the same columns of the
        B-domain buffers.  Panels for columns the data does not have are left
        empty.  Also prints the per-column mean absolute state error.
        """
        n_rows = 2
        n_cols = 4
        _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
        axes = axes.flatten()

        gt_data = np.vstack(self.gt_buffer)
        pred_data = np.vstack(self.pred_buffer)
        print(abs(gt_data - pred_data).mean(0))

        realA = np.vstack(self.realA_buffer)
        fakeA = np.vstack(self.fakeA_buffer)
        realB = np.vstack(self.realB_buffer)
        fakeB = np.vstack(self.fakeB_buffer)

        for ax_i, ax in enumerate(axes):
            if ax_i < n_cols:
                real, fake, dim, label = realA, fakeA, ax_i, 'action A'
            else:
                real, fake, dim, label = realB, fakeB, ax_i - n_cols, 'action B'
            # Fewer than four action dimensions: skip instead of IndexError.
            if dim >= real.shape[1]:
                continue
            ax.scatter(real[:, dim], fake[:, dim], s=3, label=label)

    def npdata(self, item):
        """Return ``item`` as a NumPy array on the CPU."""
        return item.cpu().data.numpy()

    def reset_buffer(self):
        """Clear every buffer used by ``show_points``."""
        self.gt_buffer = []
        self.pred_buffer = []
        self.realA_buffer = []
        self.fakeA_buffer = []
        self.realB_buffer = []
        self.fakeB_buffer = []

    def visual(self, path):
        """Save the action scatter plot to ``path`` and clear the buffers."""
        self.show_points()
        plt.legend()
        plt.savefig(path)
        # Close instead of only clearing: show_points opens a new figure on
        # every call, which would otherwise stay alive for the whole run.
        plt.close()
        self.reset_buffer()