def main(args): # reconstruction loss # (1 / L) * torch.sum(torch.log()) # mse? batch_size = args.batch_size dataset_size = args.total_data_size # cfg_file = os.path.join(args.example_config_path, args.primitive) + ".yaml" cfg = get_vae_defaults() # cfg.merge_from_file(cfg_file) cfg.freeze() vae = VAE(args.input_dimension, args.output_dimension, args.latent_dimension, hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP) if torch.cuda.is_available(): vae.encoder.cuda() vae.decoder.cuda() data_dir = args.data_dir data_loader = DataLoader(data_dir=data_dir) data_loader.create_random_ordering(size=dataset_size) total_loss = 0 for epoch in range(args.num_epochs): print('Epoch: ' + str(epoch)) running_total_loss = 0 for i in range(0, dataset_size, batch_size): vae.optimizer.zero_grad() input_batch, target_batch = data_loader.load_batch(i, batch_size) input_batch = to_var(torch.from_numpy(input_batch)) target_batch = to_var(torch.from_numpy(target_batch)) z, recon_mu, z_mu, z_logvar = vae.forward(input_batch) # target = torch.normal(z_mu) output = recon_mu kl_loss = vae.kl_loss(z_mu, z_logvar) recon_loss = vae.recon_loss(output, target_batch) # loss = vae.total_loss(output, target_batch, z_mu, z_logvar) loss = kl_loss + recon_loss loss.backward() vae.optimizer.step() running_total_loss = running_total_loss + loss.data print(" -- running loss: ") print(running_total_loss) print(' --avgerage loss: ') print(running_total_loss / (dataset_size / batch_size)) total_loss = total_loss + running_total_loss
def main(args): # cfg_file = os.path.join(args.example_config_path, args.primitive) + ".yaml" cfg = get_vae_defaults() # cfg.merge_from_file(cfg_file) cfg.freeze() batch_size = args.batch_size dataset_size = args.total_data_size if args.experiment_name is None: experiment_name = args.model_name else: experiment_name = args.experiment_name if not os.path.exists(os.path.join(args.log_dir, experiment_name)): os.makedirs(os.path.join(args.log_dir, experiment_name)) description_txt = raw_input('Please enter experiment notes: \n') if isinstance(description_txt, str): with open( os.path.join(args.log_dir, experiment_name, experiment_name + '_description.txt'), 'wb') as f: f.write(description_txt) writer = SummaryWriter(os.path.join(args.log_dir, experiment_name)) # torch_seed = np.random.randint(low=0, high=1000) # np_seed = np.random.randint(low=0, high=1000) torch_seed = 0 np_seed = 0 torch.manual_seed(torch_seed) np.random.seed(np_seed) trained_model_path = os.path.join(args.model_path, args.model_name) if not os.path.exists(trained_model_path): os.makedirs(trained_model_path) if args.task == 'contact': if args.start_rep == 'keypoints': start_dim = 24 elif args.start_rep == 'pose': start_dim = 7 if args.goal_rep == 'keypoints': goal_dim = 24 elif args.goal_rep == 'pose': goal_dim = 7 if args.skill_type == 'pull': # + 7 because single arm palm pose input_dim = start_dim + goal_dim + 7 else: # + 14 because both arms palm pose input_dim = start_dim + goal_dim + 14 output_dim = 7 decoder_input_dim = start_dim + goal_dim vae = VAE(input_dim, output_dim, args.latent_dimension, decoder_input_dim, hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP, lr=args.learning_rate) elif args.task == 'goal': if args.start_rep == 'keypoints': start_dim = 24 elif args.start_rep == 'pose': start_dim = 7 if args.goal_rep == 'keypoints': goal_dim = 24 elif args.goal_rep == 'pose': goal_dim = 7 input_dim = start_dim + goal_dim output_dim = goal_dim decoder_input_dim = start_dim vae = GoalVAE(input_dim, output_dim, args.latent_dimension, decoder_input_dim, hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP, lr=args.learning_rate) elif args.task == 'transformation': input_dim = args.input_dimension output_dim = args.output_dimension decoder_input_dim = args.input_dimension - args.output_dimension vae = GoalVAE(input_dim, output_dim, args.latent_dimension, decoder_input_dim, hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP, lr=args.learning_rate) else: raise ValueError('training task not recognized') if torch.cuda.is_available(): vae.encoder.cuda() vae.decoder.cuda() if args.start_epoch > 0: start_epoch = args.start_epoch num_epochs = args.num_epochs fname = os.path.join( trained_model_path, args.model_name + '_epoch_%d.pt' % args.start_epoch) torch_seed, np_seed = load_seed(fname) load_net_state(vae, fname) load_opt_state(vae, fname) args = load_args(fname) args.start_epoch = start_epoch args.num_epochs = num_epochs torch.manual_seed(torch_seed) np.random.seed(np_seed) data_dir = args.data_dir data_loader = DataLoader(data_dir=data_dir) data_loader.create_random_ordering(size=dataset_size) dataset = data_loader.load_dataset(start_rep=args.start_rep, goal_rep=args.goal_rep, task=args.task) total_loss = [] start_time = time.time() print('Saving models to: ' + trained_model_path) kl_weight = 1.0 print('Starting on epoch: ' + str(args.start_epoch)) for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs): print('Epoch: ' + str(epoch)) epoch_total_loss = 0 epoch_kl_loss = 0 epoch_pos_loss = 0 epoch_ori_loss = 0 epoch_recon_loss = 0 kl_coeff = 1 - kl_weight kl_weight = args.kl_anneal_rate * kl_weight print('KL coeff: ' + str(kl_coeff)) for i in range(0, dataset_size, batch_size): vae.optimizer.zero_grad() input_batch, decoder_input_batch, target_batch = \ data_loader.sample_batch(dataset, i, batch_size) input_batch = to_var(torch.from_numpy(input_batch)) decoder_input_batch = to_var(torch.from_numpy(decoder_input_batch)) z, recon_mu, z_mu, z_logvar = vae.forward(input_batch, decoder_input_batch) kl_loss = vae.kl_loss(z_mu, z_logvar) if args.task == 'contact': output_r, output_l = recon_mu if args.skill_type == 'grasp': target_batch_right = to_var( torch.from_numpy(target_batch[:, 0])) target_batch_left = to_var( torch.from_numpy(target_batch[:, 1])) pos_loss_right = vae.mse(output_r[:, :3], target_batch_right[:, :3]) ori_loss_right = vae.rotation_loss( output_r[:, 3:], target_batch_right[:, 3:]) pos_loss_left = vae.mse(output_l[:, :3], target_batch_left[:, :3]) ori_loss_left = vae.rotation_loss(output_l[:, 3:], target_batch_left[:, 3:]) pos_loss = pos_loss_left + pos_loss_right ori_loss = ori_loss_left + ori_loss_right elif args.skill_type == 'pull': target_batch = to_var( torch.from_numpy(target_batch.squeeze())) #TODO add flags for when we're training both arms # output = recon_mu[0] # right arm is index [0] # output = recon_mu[1] # left arm is index [1] pos_loss_right = vae.mse(output_r[:, :3], target_batch[:, :3]) ori_loss_right = vae.rotation_loss(output_r[:, 3:], target_batch[:, 3:]) pos_loss = pos_loss_right ori_loss = ori_loss_right elif args.task == 'goal': target_batch = to_var(torch.from_numpy(target_batch.squeeze())) output = recon_mu if args.goal_rep == 'pose': pos_loss = vae.mse(output[:, :3], target_batch[:, :3]) ori_loss = vae.rotation_loss(output[:, 3:], target_batch[:, 3:]) elif args.goal_rep == 'keypoints': pos_loss = vae.mse(output, target_batch) ori_loss = torch.zeros(pos_loss.shape) elif args.task == 'transformation': target_batch = to_var(torch.from_numpy(target_batch.squeeze())) output = recon_mu pos_loss = vae.mse(output[:, :3], target_batch[:, :3]) ori_loss = vae.rotation_loss(output[:, 3:], target_batch[:, 3:]) recon_loss = pos_loss + ori_loss loss = kl_coeff * kl_loss + recon_loss loss.backward() vae.optimizer.step() epoch_total_loss = epoch_total_loss + loss.data epoch_kl_loss = epoch_kl_loss + kl_loss.data epoch_pos_loss = epoch_pos_loss + pos_loss.data epoch_ori_loss = epoch_ori_loss + ori_loss.data epoch_recon_loss = epoch_recon_loss + recon_loss.data writer.add_scalar('loss/train/ori_loss', ori_loss.data, i) writer.add_scalar('loss/train/pos_loss', pos_loss.data, i) writer.add_scalar('loss/train/kl_loss', kl_loss.data, i) if (i / batch_size) % args.batch_freq == 0: if args.skill_type == 'pull' or args.task == 'goal' or args.task == 'transformation': print( 'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tPos: %f\t Ori: %f' % (epoch, i, dataset_size, 100.0 * i / dataset_size / batch_size, loss.item(), kl_loss.item(), pos_loss.item(), ori_loss.item())) elif args.skill_type == 'grasp' and args.task == 'contact': print( 'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tR Pos: %f\t R Ori: %f\tL Pos: %f\tL Ori: %f' % (epoch, i, dataset_size, 100.0 * i / dataset_size / batch_size, loss.item(), kl_loss.item(), pos_loss_right.item(), ori_loss_right.item(), pos_loss_left.item(), ori_loss_left.item())) print(' --avgerage loss: ') print(epoch_total_loss / (dataset_size / batch_size)) loss_dict = { 'epoch_total': epoch_total_loss / (dataset_size / batch_size), 'epoch_kl': epoch_kl_loss / (dataset_size / batch_size), 'epoch_pos': epoch_pos_loss / (dataset_size / batch_size), 'epoch_ori': epoch_ori_loss / (dataset_size / batch_size), 'epoch_recon': epoch_recon_loss / (dataset_size / batch_size) } total_loss.append(loss_dict) if epoch % args.save_freq == 0: print('\n--Saving model\n') print('time: ' + str(time.time() - start_time)) save_state(net=vae, torch_seed=torch_seed, np_seed=np_seed, args=args, fname=os.path.join( trained_model_path, args.model_name + '_epoch_' + str(epoch) + '.pt')) np.savez(os.path.join( trained_model_path, args.model_name + '_epoch_' + str(epoch) + '_loss.npz'), loss=np.asarray(total_loss)) print('Done!') save_state(net=vae, torch_seed=torch_seed, np_seed=np_seed, args=args, fname=os.path.join( trained_model_path, args.model_name + '_epoch_' + str(epoch) + '.pt'))