def train_imgf(opt, epochs=50):
    """Train an ``ImgFmodel`` on ``Robotdata`` and save weights plus previews.

    Each epoch makes one pass over the loader, minimising ``100 * MSE``
    between ``model(state, action)`` and ``result`` with Adam. At the end of
    every epoch the weights are written to ``./imgpred.pth``. A preview of
    the epoch's last batch is written to ``./tmp/imgpred/img_<epoch>.jpg``.

    Args:
        opt: option namespace forwarded to ``ImgFmodel`` and
            ``Robotdata.get_loader``.
        epochs: number of passes over the loader (defaults to 50, the value
            that used to be hard-coded).
    """
    # Construct the model the same way as the other ImgFmodel call site (opt=...).
    model = ImgFmodel(opt=opt).cuda()
    dataset = Robotdata.get_loader(opt)
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.MSELoss()

    preview_dir = './tmp/imgpred'
    # img.save() does not create missing directories.
    os.makedirs(preview_dir, exist_ok=True)

    for epoch in range(epochs):
        last_batch = None
        for i, item in enumerate(dataset):
            # NOTE(review): presumably item[0] is a (state, action, result) triple — confirm.
            state, action, result = item[0]
            state = state.float().cuda()
            action = action.float().cuda()
            result = result.float().cuda()

            out = model(state, action)
            # The x100 factor only rescales the MSE value.
            loss = loss_fn(out, result) * 100
            print(epoch, i, loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep the most recent batch for the end-of-epoch preview (no graph).
            last_batch = (state, result, out.detach())

        torch.save(model.state_dict(), './imgpred.pth')

        if last_batch is None:
            # Empty loader: there is no batch to visualise this epoch.
            continue
        _save_preview(last_batch,
                      os.path.join(preview_dir, 'img_{}.jpg'.format(epoch)))


def _save_preview(batch, path):
    """Save a single image tiling every sample of *batch* side by side.

    *batch* is ``(state, result, out)``. For each sample, ``before``,
    ``after`` and ``pred`` are concatenated along dim 1. The per-sample
    tiles are then concatenated along dim 2. (Presumably each sample is
    (C, H, W), so tiles stack vertically and sit side by side — confirm.)
    """
    state, result, out = batch
    with torch.no_grad():
        tiles = [torch.cat([before, after, pred], 1)
                 for before, after, pred in zip(state, result, out)]
        merged = torch.cat(tiles, 2).cpu()
        # Map [-1, 1] -> [0, 1]. Clamp so out-of-range predictions do not
        # overflow when ToPILImage scales floats by 255 and casts to uint8.
        merged = ((merged + 1) / 2).clamp(0, 1)
    img = transforms.ToPILImage()(merged)
    # img = transforms.Resize((512, 640))(img)
    img.save(path)
def __init__(self, opt):
    """Build the generators, forward models, discriminators, losses and optimizers.

    Args:
        opt: option namespace. Read directly here: ``istrain``,
            ``pretrain_f``, ``loss`` (``'l1'`` or ``'l2'``), ``F_lr``,
            ``G_lr`` and ``A_lr``. It is also forwarded to every sub-model
            and to ``Robotdata.get_loader``.

    Raises:
        ValueError: if ``opt.loss`` is neither ``'l1'`` nor ``'l2'``.
    """
    # Fail fast on an unsupported cycle-loss name. Otherwise criterionCycle
    # is never assigned, and the error only shows up much later, after
    # train_forward_state has already run.
    if opt.loss not in ('l1', 'l2'):
        raise ValueError(
            "opt.loss must be 'l1' or 'l2', got {!r}".format(opt.loss))

    self.opt = opt
    self.isTrain = opt.istrain
    self.Tensor = torch.cuda.FloatTensor

    # Generators: two state2state nets plus AGmodel in 'A2B' and 'B2A' directions.
    self.netG_A = state2state(opt=self.opt).cuda()
    self.netG_B = state2state(opt=self.opt).cuda()
    self.net_action_G_A = AGmodel(flag='A2B', opt=self.opt).cuda()
    self.net_action_G_B = AGmodel(flag='B2A', opt=self.opt).cuda()

    # Forward models (Fmodel and ImgFmodel) and their data loader.
    self.netF_A = Fmodel(self.opt).cuda()
    self.netF_B = ImgFmodel(opt=self.opt).cuda()
    self.dataF = Robotdata.get_loader(opt)
    # train_forward_state / reset_buffer are defined elsewhere in the class.
    self.train_forward_state(pretrained=opt.pretrain_f)
    # NOTE(review): image forward-model pre-training is currently disabled:
    #     self.train_forward_img(pretrained=True)
    self.reset_buffer()

    # NOTE(review): the `if self.isTrain:` guards here were commented out, so
    # discriminators, pools and losses are built regardless of opt.istrain.
    self.netD_A = stateDmodel(opt=self.opt).cuda()
    self.netD_B = stateDmodel(opt=self.opt).cuda()
    self.net_action_D_A = ADmodel(opt=self.opt).cuda()
    self.net_action_D_B = ADmodel(opt=self.opt).cuda()

    self.fake_A_pool = ImagePool(pool_size=128)
    self.fake_B_pool = ImagePool(pool_size=128)
    self.fake_action_A_pool = ImagePool(pool_size=128)
    self.fake_action_B_pool = ImagePool(pool_size=128)

    # define loss functions
    self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
    if opt.loss == 'l1':
        self.criterionCycle = nn.L1Loss()
    elif opt.loss == 'l2':
        self.criterionCycle = nn.MSELoss()
    self.ImgcriterionCycle = nn.MSELoss()
    self.StatecriterionCycle = nn.L1Loss()

    # initialize optimizers: one Adam over all generator-side nets, with a
    # separate learning rate per parameter group (F_lr / G_lr / A_lr).
    parameters = [{'params': self.netF_A.parameters(), 'lr': self.opt.F_lr},
                  {'params': self.netF_B.parameters(), 'lr': self.opt.F_lr},
                  {'params': self.netG_A.parameters(), 'lr': self.opt.G_lr},
                  {'params': self.netG_B.parameters(), 'lr': self.opt.G_lr},
                  {'params': self.net_action_G_A.parameters(), 'lr': self.opt.A_lr},
                  {'params': self.net_action_G_B.parameters(), 'lr': self.opt.A_lr}]
    self.optimizer_G = torch.optim.Adam(parameters)
    # Discriminators each get their own Adam with default hyper-parameters.
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
    self.optimizer_action_D_A = torch.optim.Adam(self.net_action_D_A.parameters())
    self.optimizer_action_D_B = torch.optim.Adam(self.net_action_D_B.parameters())

    print('---------- Networks initialized ---------------')
    print('-----------------------------------------------')