def main(args):
    #global hl
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
    # environment setting
    cae = cae_identity
    mlp = MLP
    cpp_propagator = _sst_module.SystemPropagator()
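    # Each branch below selects the environment-specific pieces:
    #   normalize/unnormalize   : map raw states to/from the network's input range
    #   system                  : C++ system object used by the propagator
    #   mlp / cae               : planner MLP class and voxel-grid encoder class
    #   dynamics                : closure over the C++ propagator (None disables dense trajectory integration)
    #   pos_indices/vel_indices : state dimensions handled by the position / velocity nets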
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'cartpole':
        normalize = cart_pole.normalize
        unnormalize = cart_pole.unnormalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = [0, 2]
        vel_indices = [1, 3]
    elif args.env_type == 'cartpole_obs_2':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP2
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = [0, 2]
        vel_indices = [1, 3]

    elif args.env_type == 'cartpole_obs_3':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP4
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = [0, 2]
        vel_indices = [1, 3]
        
    elif args.env_type == 'cartpole_obs_4_small':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 2])
        vel_indices = np.array([1, 3])
    elif args.env_type == 'cartpole_obs_4_big':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 2])
        vel_indices = np.array([1, 3])
    elif args.env_type == 'cartpole_obs_4_small_x_theta':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 1])
        vel_indices = np.array([2, 3])
    elif args.env_type == 'cartpole_obs_4_big_x_theta':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 1])
        vel_indices = np.array([2, 3])
    elif args.env_type == 'cartpole_obs_4_small_decouple_output':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 2])
        vel_indices = np.array([1, 3])
    elif args.env_type == 'cartpole_obs_4_big_decouple_output':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3
        cae = CAE_cartpole_voxel_2d
        
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        pos_indices = np.array([0, 2])
        vel_indices = np.array([1, 3])

        
        
    elif args.env_type == 'acrobot_obs':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_2':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP2
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_3':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_4':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_5':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_6':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP4
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]

    elif args.env_type == 'acrobot_obs_7':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP5
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]
    elif args.env_type == 'acrobot_obs_8':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP6
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        pos_indices = [0, 1]
        vel_indices = [2, 3]



    # set loss for mpnet
    if args.loss == 'mse':
        #mpnet.loss_f = nn.MSELoss()
        def mse_loss(y1, y2):
            l = (y1 - y2) ** 2
            l = torch.mean(l, dim=0)  # average along the batch dimension; result has one entry per output dimension
            return l
        loss_f_p = mse_loss
        loss_f_v = mse_loss

    elif args.loss == 'l1_smooth':
        #mpnet.loss_f = nn.SmoothL1Loss()
        def l1_smooth_loss(y1, y2):
            l1 = torch.abs(y1 - y2)
            cond = l1 < 1
            l = torch.where(cond, 0.5 * l1 ** 2, l1)
            l = torch.mean(l, dim=0)  # average along the batch dimension; result has one entry per output dimension
            return l
        loss_f_p = l1_smooth_loss
        loss_f_v = l1_smooth_loss

    elif args.loss == 'mse_decoupled':
        def mse_decoupled(y1, y2):
            # for angle terms, wrap it to -pi~pi
            l_0 = torch.abs(y1[:,0] - y2[:,0]) ** 2
            l_1 = torch.abs(y1[:,1] - y2[:,1]) ** 2
            l_2 = torch.abs(y1[:,2] - y2[:,2]) # angular dimension
            l_3 = torch.abs(y1[:,3] - y2[:,3]) ** 2
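            # after normalization an angle of np.pi maps to 1.0, so for a raw gap
            # |d| in (1, 2] the true wrapped angular distance is 2 - |d|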
            cond = (l_2 > 1.0) * (l_2 <= 2.0)
            l_2 = torch.where(cond, 2.0 - l_2, l_2)
            l_2 = l_2 ** 2
            l_0 = torch.mean(l_0)
            l_1 = torch.mean(l_1)
            l_2 = torch.mean(l_2)
            l_3 = torch.mean(l_3)
            return torch.stack([l_0, l_1, l_2, l_3])
        loss_f_p = mse_decoupled
        loss_f_v = mse_decoupled

    elif args.loss == 'l1_smooth_decoupled':
        
        # NOTE: this is cartpole-specific and needs adapting for other systems
        # TODO: choose the wrapped dimensions based on the env type
        def l1_smooth_decoupled(y1, y2):
            # for angle terms, wrap it to -pi~pi
            l_0 = torch.abs(y1[:,0] - y2[:,0])
            l_1 = torch.abs(y1[:,1] - y2[:,1]) # angular dimension
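            # np.pi after normalization is 1.0; wrap raw gaps in (1, 2] to 2 - |d|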
            cond = (l_1 > 1.0) * (l_1 <= 2.0)
            l_1 = torch.where(cond, 2.0 - l_1, l_1)
            
            # then change to l1_smooth_loss
            cond = l_0 < 1
            l_0 = torch.where(cond, 0.5 * l_0 ** 2, l_0)
            cond = l_1 < 1
            l_1 = torch.where(cond, 0.5 * l_1 ** 2, l_1)
            
            l_0 = torch.mean(l_0)
            l_1 = torch.mean(l_1)
            return torch.stack([l_0, l_1])
        def l1_smooth_loss(y1, y2):
            l1 = torch.abs(y1 - y2)
            cond = l1 < 1
            l = torch.where(cond, 0.5 * l1 ** 2, l1)
            l = torch.mean(l, dim=0)  # average along the batch dimension; result has one entry per output dimension
            return l
        loss_f_p = l1_smooth_decoupled
        loss_f_v = l1_smooth_loss


    if 'decouple_output' in args.env_type:
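        # only the output is split here: both nets receive the full input, and each
        # predicts half of the output dimensions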
        print('mpnet using decoupled output')
        mpnet_pnet = KMPNet(args.total_input_size, args.AE_input_size, args.mlp_input_size, args.output_size//2,
                       cae, mlp, loss_f_p)
        mpnet_vnet = KMPNet(args.total_input_size, args.AE_input_size, args.mlp_input_size, args.output_size//2,
                       cae, mlp, loss_f_v)
    else:
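        # input is split as well: each net sees only its own coordinates of the
        # start and goal states, hence total_input_size//2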
        mpnet_pnet = KMPNet(args.total_input_size//2, args.AE_input_size, args.mlp_input_size, args.output_size//2,
                       cae, mlp, loss_f_p)
        mpnet_vnet = KMPNet(args.total_input_size//2, args.AE_input_size, args.mlp_input_size, args.output_size//2,
                       cae, mlp, loss_f_v)
        
    # load net
    # load previously trained model if start epoch > 0

    model_dir = args.model_dir
    if args.loss == 'mse':
        if args.multigoal == 0:
            model_dir = model_dir+args.env_type+"_lr%f_%s_step_%d/" % (args.learning_rate, args.opt, args.num_steps)
        else:
            model_dir = model_dir+args.env_type+"_lr%f_%s_step_%d_multigoal/" % (args.learning_rate, args.opt, args.num_steps)
    else:
        if args.multigoal == 0:
            model_dir = model_dir+args.env_type+"_lr%f_%s_loss_%s_step_%d/" % (args.learning_rate, args.opt, args.loss, args.num_steps)
        else:
            model_dir = model_dir+args.env_type+"_lr%f_%s_loss_%s_step_%d_multigoal/" % (args.learning_rate, args.opt, args.loss, args.num_steps)


    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_pnet_path='kmpnet_pnet_epoch_%d_direction_%d_step_%d.pkl' %(args.start_epoch, args.direction, args.num_steps)
    model_vnet_path='kmpnet_vnet_epoch_%d_direction_%d_step_%d.pkl' %(args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        #load_net_state(mpnet, os.path.join(args.model_path, model_path))
        load_net_state(mpnet_pnet, os.path.join(model_dir, model_pnet_path))
        load_net_state(mpnet_vnet, os.path.join(model_dir, model_vnet_path))

        #torch_seed, np_seed, py_seed = load_seed(os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(os.path.join(model_dir, model_pnet_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet_pnet.cuda()
        mpnet_pnet.mlp.cuda()
        mpnet_pnet.encoder.cuda()

        mpnet_vnet.cuda()
        mpnet_vnet.mlp.cuda()
        mpnet_vnet.encoder.cuda()

        if args.opt == 'Adagrad':
            mpnet_pnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet_pnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet_pnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet_pnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)

            
        if args.opt == 'Adagrad':
            mpnet_vnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet_vnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet_vnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet_vnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)

            
            
        if args.start_epoch > 0:
            #load_opt_state(mpnet, os.path.join(args.model_path, model_path))
            load_opt_state(mpnet_pnet, os.path.join(model_dir, model_pnet_path))
            load_opt_state(mpnet_vnet, os.path.join(model_dir, model_vnet_path))


    # load train and test data
    print('loading...')
    obs, waypoint_dataset, waypoint_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset(N=args.no_env, NP=args.no_motion_paths,
                                                data_folder=args.path_folder, obs_f=True,
                                                direction=args.direction,
                                                dynamics=dynamics, enforce_bounds=enforce_bounds,
                                                system=system, step_sz=step_sz,
                                                num_steps=args.num_steps, multigoal=args.multigoal)
    # randomize the dataset before training
    data=list(zip(waypoint_dataset,waypoint_targets,env_indices))
    random.shuffle(data)
    dataset,targets,env_indices=list(zip(*data))
    dataset = list(dataset)
    dataset = np.array(dataset)
    targets = np.array(targets)
    # make sure the index arrays support arithmetic (some branches define them as plain lists)
    pos_indices = np.array(pos_indices)
    vel_indices = np.array(vel_indices)
    print(np.concatenate([pos_indices, pos_indices + args.total_input_size//2]))
    p_dataset = dataset[:, np.concatenate([pos_indices, pos_indices + args.total_input_size//2])]
    v_dataset = dataset[:, np.concatenate([vel_indices, vel_indices + args.total_input_size//2])]
    if 'decouple_output' in args.env_type:
        # only decouple output
        print('only decouple output but not input')
        p_dataset = dataset
        v_dataset = dataset
    print(p_dataset.shape)
    print(v_dataset.shape)
    
    
    
    p_targets = targets[:, pos_indices]
    v_targets = targets[:, vel_indices]  # this split is cartpole-specific; TODO: select indices per env type

    p_targets = list(p_targets)
    v_targets = list(v_targets)
    #targets = list(targets)
    env_indices = list(env_indices)
    dataset = np.array(dataset)
    #targets = np.array(targets)
    env_indices = np.array(env_indices)

    # use 5% as validation dataset
    val_len = int(len(dataset) * 0.05)
    val_p_dataset = p_dataset[-val_len:]
    val_v_dataset = v_dataset[-val_len:]
    val_p_targets = p_targets[-val_len:]
    val_v_targets = v_targets[-val_len:]
    val_env_indices = env_indices[-val_len:]

    p_dataset = p_dataset[:-val_len]
    v_dataset = v_dataset[:-val_len]
    p_targets = p_targets[:-val_len]
    v_targets = v_targets[:-val_len]
    env_indices = env_indices[:-val_len]

    # Train the Models
    print('training...')
    if args.loss == 'mse':
        if args.multigoal == 0:
            writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d' % (args.env_type, args.learning_rate, args.opt, args.direction, args.num_steps)
        else:
            writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d_multigoal' % (args.env_type, args.learning_rate, args.opt, args.direction, args.num_steps)
    else:
        if args.multigoal == 0:
            writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d_loss_%s' % (args.env_type, args.learning_rate, args.opt, args.direction, args.num_steps, args.loss)
        else:
            writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d_loss_%s_multigoal' % (args.env_type, args.learning_rate, args.opt, args.direction, args.num_steps, args.loss)


    writer = SummaryWriter('./runs/'+writer_fname)
    record_i = 0
    val_record_i = 0
    p_loss_avg_i = 0
    p_val_loss_avg_i = 0
    p_loss_avg = 0.
    p_val_loss_avg = 0.
    v_loss_avg_i = 0
    v_val_loss_avg_i = 0
    v_loss_avg = 0.
    v_val_loss_avg = 0.

    loss_steps = 100  # record every 100 loss
    
    
    # slice the world-size bounds so each net normalizes only its own state dimensions
    world_size = np.array(args.world_size)
    pos_world_size = list(world_size[pos_indices])
    vel_world_size = list(world_size[vel_indices])
    

    
    for epoch in range(args.start_epoch+1,args.num_epochs+1):
        print('epoch ' + str(epoch))
        val_i = 0
        for i in range(0,len(p_dataset),args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i+1))
            p_dataset_i = p_dataset[i:i+args.batch_size]
            v_dataset_i = v_dataset[i:i+args.batch_size]
            p_targets_i = p_targets[i:i+args.batch_size]
            v_targets_i = v_targets[i:i+args.batch_size]
            env_indices_i = env_indices[i:i+args.batch_size]
            # record
            p_bi = p_dataset_i.astype(np.float32)
            v_bi = v_dataset_i.astype(np.float32)
            print('p_bi shape:')
            print(p_bi.shape)
            print('v_bi shape:')
            print(v_bi.shape)
            p_bt = p_targets_i
            v_bt = v_targets_i
            p_bi = torch.FloatTensor(p_bi)
            v_bi = torch.FloatTensor(v_bi)
            p_bt = torch.FloatTensor(p_bt)
            v_bt = torch.FloatTensor(v_bt)

            # edit: disable this for investigation of the good weights for training, and for wrapping
            if 'decouple_output' in args.env_type:
                print('using normalization of decoupled output')
                # only decouple output but not input
                p_bi, v_bi, p_bt, v_bt = normalize(p_bi, args.world_size), normalize(v_bi, args.world_size), normalize(p_bt, pos_world_size), normalize(v_bt, vel_world_size)
            else:
                p_bi, v_bi, p_bt, v_bt = normalize(p_bi, pos_world_size), normalize(v_bi, vel_world_size), normalize(p_bt, pos_world_size), normalize(v_bt, vel_world_size)


            mpnet_pnet.zero_grad()
            mpnet_vnet.zero_grad()

            p_bi=to_var(p_bi)
            v_bi=to_var(v_bi)
            p_bt=to_var(p_bt)
            v_bt=to_var(v_bt)

            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            print('-------pnet-------')
            print('before training losses:')
            print(mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt))
            mpnet_pnet.step(p_bi, bobs, p_bt)
            print('after training losses:')
            print(mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt))
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            #update_line(hl, ax, [i//args.batch_size, loss.data.numpy()])
            p_loss_avg += p_loss.cpu().data
            p_loss_avg_i += 1
            
            print('-------vnet-------')
            print('before training losses:')
            print(mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt))
            mpnet_vnet.step(v_bi, bobs, v_bt)
            print('after training losses:')
            print(mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt))
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            #update_line(hl, ax, [i//args.batch_size, loss.data.numpy()])
            v_loss_avg += v_loss.cpu().data
            v_loss_avg_i += 1
            

            if p_loss_avg_i >= loss_steps:
                p_loss_avg = p_loss_avg / p_loss_avg_i
                writer.add_scalar('p_train_loss_0', p_loss_avg[0], record_i)
                writer.add_scalar('p_train_loss_1', p_loss_avg[1], record_i)

                v_loss_avg = v_loss_avg / v_loss_avg_i
                writer.add_scalar('v_train_loss_0', v_loss_avg[0], record_i)
                writer.add_scalar('v_train_loss_1', v_loss_avg[1], record_i)

                record_i += 1
                p_loss_avg = 0.
                p_loss_avg_i = 0

                v_loss_avg = 0.
                v_loss_avg_i = 0

                
            # validation
            # calculate the corresponding batch in val_dataset
            p_dataset_i = val_p_dataset[val_i:val_i+args.batch_size]
            v_dataset_i = val_v_dataset[val_i:val_i+args.batch_size]

            p_targets_i = val_p_targets[val_i:val_i+args.batch_size]
            v_targets_i = val_v_targets[val_i:val_i+args.batch_size]

            env_indices_i = val_env_indices[val_i:val_i+args.batch_size]
            val_i = val_i + args.batch_size
            if val_i > val_len:
                val_i = 0
            # record
            p_bi = p_dataset_i.astype(np.float32)
            v_bi = v_dataset_i.astype(np.float32)

            print('p_bi shape:')
            print(p_bi.shape)
            print('v_bi shape:')
            print(v_bi.shape)

            p_bt = p_targets_i
            v_bt = v_targets_i
            p_bi = torch.FloatTensor(p_bi)
            v_bi = torch.FloatTensor(v_bi)

            p_bt = torch.FloatTensor(p_bt)
            v_bt = torch.FloatTensor(v_bt)
            if 'decouple_output' in args.env_type:
                # only decouple output but not input
                p_bi, v_bi, p_bt, v_bt = normalize(p_bi, args.world_size), normalize(v_bi, args.world_size), normalize(p_bt, pos_world_size), normalize(v_bt, vel_world_size)
            else:
                p_bi, v_bi, p_bt, v_bt = normalize(p_bi, pos_world_size), normalize(v_bi, vel_world_size), normalize(p_bt, pos_world_size), normalize(v_bt, vel_world_size)
                
            p_bi=to_var(p_bi)
            v_bi=to_var(v_bi)
            p_bt=to_var(p_bt)
            v_bt=to_var(v_bt)

            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            print('-------pnet loss--------')
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            print('validation loss:', p_loss.cpu().data)

            p_val_loss_avg += p_loss.cpu().data
            p_val_loss_avg_i += 1

            print('-------vnet loss--------')
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            print('validation loss:', v_loss.cpu().data)

            v_val_loss_avg += v_loss.cpu().data
            v_val_loss_avg_i += 1

            
            if p_val_loss_avg_i >= loss_steps:
                p_val_loss_avg = p_val_loss_avg / p_val_loss_avg_i
                writer.add_scalar('p_val_loss_0', p_val_loss_avg[0], val_record_i)
                writer.add_scalar('p_val_loss_1', p_val_loss_avg[1], val_record_i)
                v_val_loss_avg = v_val_loss_avg / v_val_loss_avg_i
                writer.add_scalar('v_val_loss_0', v_val_loss_avg[0], val_record_i)
                writer.add_scalar('v_val_loss_1', v_val_loss_avg[1], val_record_i)

                
                val_record_i += 1
                p_val_loss_avg = 0.
                p_val_loss_avg_i = 0
                
                v_val_loss_avg = 0.
                v_val_loss_avg_i = 0

        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            model_pnet_path='kmpnet_pnet_epoch_%d_direction_%d_step_%d.pkl' %(epoch, args.direction, args.num_steps)
            model_vnet_path='kmpnet_vnet_epoch_%d_direction_%d_step_%d.pkl' %(epoch, args.direction, args.num_steps)
            #save_state(mpnet, torch_seed, np_seed, py_seed, os.path.join(args.model_path,model_path))
            save_state(mpnet_pnet, torch_seed, np_seed, py_seed, os.path.join(model_dir,model_pnet_path))
            save_state(mpnet_vnet, torch_seed, np_seed, py_seed, os.path.join(model_dir,model_vnet_path))

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #2
def main(args):
    #global hl

    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
    # environment setting
    multigoal = False
    cpp_propagator = _sst_module.SystemPropagator()
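    # as in the planner training above, each branch selects normalize/unnormalize,
    # the C++ system, the mlp/cae classes, the propagator-backed dynamics, and the
    # integration step size for one environment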
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'cartpole':
        normalize = cart_pole.normalize
        unnormalize = cart_pole.unnormalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        cae = cae_identity
        mlp = MLP
    elif args.env_type == 'cartpole_obs_4':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3_no_dropout
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        multigoal = False
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs_4_multigoal':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3_no_dropout
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        multigoal = True

        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'acrobot_obs':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_2':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP2
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_3':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_4':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_5':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_6':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP4
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_7':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP5
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_8':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP6
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20

    # set loss for mpnet
    if args.loss == 'mse':
        #mpnet.loss_f = nn.MSELoss()
        def mse_loss(y1, y2):
            l = (y1 - y2)**2
            l = torch.mean(
                l, dim=0
            )  # average along the batch dimension; result has one entry per output dimension
            return l

        loss_f = mse_loss

    elif args.loss == 'l1_smooth':
        #mpnet.loss_f = nn.SmoothL1Loss()
        def l1_smooth_loss(y1, y2):
            l1 = torch.abs(y1 - y2)
            cond = l1 < 1
            l = torch.where(cond, 0.5 * l1**2, l1)
            l = torch.mean(
                l, dim=0
            )  # average along the batch dimension; result has one entry per output dimension
            return l

        loss_f = l1_smooth_loss

    elif args.loss == 'mse_decoupled':

        def mse_decoupled(y1, y2):
            # for angle terms, wrap it to -pi~pi
            l_0 = torch.abs(y1[:, 0] - y2[:, 0])**2
            l_1 = torch.abs(y1[:, 1] - y2[:, 1])**2
            l_2 = torch.abs(y1[:, 2] - y2[:, 2])  # angular dimension
            l_3 = torch.abs(y1[:, 3] - y2[:, 3])**2

            cond = (l_2 > 1.0) * (l_2 <= 2.0)  # np.pi after normalization is 1.0
            l_2 = torch.where(cond, 2.0 - l_2, l_2)
            l_2 = l_2**2
            l_0 = torch.mean(l_0)
            l_1 = torch.mean(l_1)
            l_2 = torch.mean(l_2)
            l_3 = torch.mean(l_3)
            return torch.stack([l_0, l_1, l_2, l_3])

        loss_f = mse_decoupled
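
    # a single KMPNet is trained here to predict path cost; the dataset comes from
    # load_train_dataset_cost below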

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp, loss_f)
    # load net
    # load previously trained model if start epoch > 0
    model_dir = args.model_dir
    model_dir = model_dir + 'cost_' + args.env_type + "_lr%f_%s_step_%d/" % (
        args.learning_rate, args.opt, args.num_steps)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        #load_net_state(mpnet, os.path.join(args.model_path, model_path))
        load_net_state(mpnet, os.path.join(model_dir, model_path))
        #torch_seed, np_seed, py_seed = load_seed(os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(model_dir, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        if args.opt == 'Adagrad':
            mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        #load_opt_state(mpnet, os.path.join(args.model_path, model_path))
        load_opt_state(mpnet, os.path.join(model_dir, model_path))

    # load train and test data
    print('loading...')
    obs, cost_dataset, cost_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset_cost(N=args.no_env, NP=args.no_motion_paths,
                                                data_folder=args.path_folder, obs_f=True,
                                                direction=args.direction,
                                                dynamics=dynamics, enforce_bounds=enforce_bounds,
                                                system=system, step_sz=step_sz, num_steps=args.num_steps,
                                                multigoal=multigoal)
    # randomize the dataset before training
    data = list(zip(cost_dataset, cost_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = list(zip(*data))
    dataset = list(dataset)
    targets = list(targets)
    env_indices = list(env_indices)
    dataset = np.array(dataset)
    targets = np.array(targets)
    env_indices = np.array(env_indices)

    # use 5% as validation dataset
    val_len = int(len(dataset) * 0.05)
    val_dataset = dataset[-val_len:]
    val_targets = targets[-val_len:]
    val_env_indices = env_indices[-val_len:]

    dataset = dataset[:-val_len]
    targets = targets[:-val_len]
    env_indices = env_indices[:-val_len]

    # Train the Models
    print('training...')
    writer_fname = 'cost_%s_%f_%s_direction_%d_step_%d' % (
        args.env_type, args.learning_rate, args.opt, args.direction,
        args.num_steps)
    writer = SummaryWriter('./runs/' + writer_fname)
    record_i = 0
    val_record_i = 0
    loss_avg_i = 0
    val_loss_avg_i = 0
    loss_avg = 0.
    val_loss_avg = 0.
    loss_steps = 100  # record every 100 loss
    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        print('epoch ' + str(epoch))
        val_i = 0
        for i in range(0, len(dataset), args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i + 1))
            dataset_i = dataset[i:i + args.batch_size]
            targets_i = targets[i:i + args.batch_size]
            env_indices_i = env_indices[i:i + args.batch_size]
            # record
            bi = dataset_i.astype(np.float32)
            print('bi shape:')
            print(bi.shape)
            bt = targets_i
            bi = torch.FloatTensor(bi)
            bt = torch.FloatTensor(bt)
            bi = normalize(bi, args.world_size)
            mpnet.zero_grad()
            bi = to_var(bi)
            bt = to_var(bt)
            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            print('before training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            mpnet.step(bi, bobs, bt)
            print('after training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            #update_line(hl, ax, [i//args.batch_size, loss.data.numpy()])
            loss_avg += loss.cpu().data
            loss_avg_i += 1
            if loss_avg_i >= loss_steps:
                loss_avg = loss_avg / loss_avg_i
                writer.add_scalar('train_loss', loss_avg, record_i)
                record_i += 1
                loss_avg = 0.
                loss_avg_i = 0

            # validation
            # calculate the corresponding batch in val_dataset
            dataset_i = val_dataset[val_i:val_i + args.batch_size]
            targets_i = val_targets[val_i:val_i + args.batch_size]
            env_indices_i = val_env_indices[val_i:val_i + args.batch_size]
            val_i = val_i + args.batch_size
            if val_i > val_len:
                val_i = 0
            # record
            bi = dataset_i.astype(np.float32)
            print('bi shape:')
            print(bi.shape)
            bt = targets_i
            bi = torch.FloatTensor(bi)
            bt = torch.FloatTensor(bt)
            bi = normalize(bi, args.world_size)
            bi = to_var(bi)
            bt = to_var(bt)
            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            print('validation loss: %f' % (loss.cpu().data))

            val_loss_avg += loss.cpu().data
            val_loss_avg_i += 1
            if val_loss_avg_i >= loss_steps:
                val_loss_avg = val_loss_avg / val_loss_avg_i
                writer.add_scalar('val_loss', val_loss_avg, val_record_i)
                val_record_i += 1
                val_loss_avg = 0.
                val_loss_avg_i = 0
        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
                epoch, args.direction, args.num_steps)
            #save_state(mpnet, torch_seed, np_seed, py_seed, os.path.join(args.model_path,model_path))
            save_state(mpnet, torch_seed, np_seed, py_seed,
                       os.path.join(model_dir, model_path))
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()