コード例 #1
0
ファイル: agents.py プロジェクト: arnaudvl/world-models-ppo
    def __init__(self,
                 action_dim: int,
                 output_dim: int,
                 vae_path: str = './vae/model/best.tar',
                 mdnrnn_path: str = './mdnrnn/model/best.tar',
                 activation: Callable = torch.relu,
                 output_activation: Union[Callable, str] = None,
                 output_squeeze: bool = False) -> None:
        """Build a controller head on top of frozen, pre-trained VAE and MDN-RNN.

        Args:
            action_dim: action size handed to the ``MDNRNN`` constructor.
            output_dim: number of output units of the last linear layer.
            vae_path: checkpoint file whose ``'state_dict'`` entry holds the
                VAE weights.
            mdnrnn_path: checkpoint file whose ``'state_dict'`` entry holds
                the MDN-RNN weights.
            activation: stored as ``self.activation``.
            output_activation: stored as ``self.output_activation``.
            output_squeeze: stored as ``self.output_squeeze``.
        """
        super(WorldModel, self).__init__()

        # VAE: restore the weights and freeze them.
        # The checkpoint is mapped to CPU while loading so that a file saved
        # on a GPU machine can still be restored on a CPU-only host;
        # ``load_state_dict`` copies the values into the module's parameters.
        self.latent_size = LATENT_SIZE
        self.vae = VAE(CHANNELS, LATENT_SIZE)
        vae_state = torch.load(vae_path, map_location='cpu')
        self.vae.load_state_dict(vae_state['state_dict'])
        for param in self.vae.parameters():
            param.requires_grad = False

        # MDN-RNN: same procedure, restored and frozen.
        self.n_gauss = N_GAUSS
        self.mdnrnn = MDNRNN(LATENT_SIZE, action_dim, HIDDEN_SIZE, N_GAUSS, rewards_terminal=False)
        mdnrnn_state = torch.load(mdnrnn_path, map_location='cpu')
        self.mdnrnn.load_state_dict(mdnrnn_state['state_dict'])
        for param in self.mdnrnn.parameters():
            param.requires_grad = False

        # Controller: the only layers left trainable by this constructor.
        ctr_size = LATENT_SIZE + N_GAUSS + 2 * (LATENT_SIZE * N_GAUSS)
        self.fc1 = nn.Linear(ctr_size, 512)
        self.fc2 = nn.Linear(512, output_dim)

        self.activation = activation
        self.output_activation = output_activation
        self.output_squeeze = output_squeeze
コード例 #2
0
ファイル: agents.py プロジェクト: arnaudvl/world-models-ppo
class WorldModel(nn.Module):
    """Controller network fed by a frozen VAE and a frozen MDN-RNN.

    The VAE and MDN-RNN weights are restored from checkpoints and excluded
    from gradient updates; only the two linear controller layers are left
    trainable by this class.
    """

    def __init__(self,
                 action_dim: int,
                 output_dim: int,
                 vae_path: str = './vae/model/best.tar',
                 mdnrnn_path: str = './mdnrnn/model/best.tar',
                 activation: Callable = torch.relu,
                 output_activation: Union[Callable, str] = None,
                 output_squeeze: bool = False) -> None:
        """Restore the pre-trained models and create the controller layers.

        Args:
            action_dim: action size handed to the ``MDNRNN`` constructor.
            output_dim: number of output units of the controller.
            vae_path: checkpoint file whose ``'state_dict'`` entry holds the
                VAE weights.
            mdnrnn_path: checkpoint file whose ``'state_dict'`` entry holds
                the MDN-RNN weights.
            activation: applied to the hidden controller layer.
            output_activation: applied to the controller output when not
                ``None``; ``forward`` calls it, so it must be callable.
            output_squeeze: when true, ``forward`` returns ``x.squeeze()``.
        """
        super(WorldModel, self).__init__()

        # VAE: restore the weights and freeze them.
        # The checkpoint is mapped to CPU while loading so that a file saved
        # on a GPU machine can still be restored on a CPU-only host;
        # ``load_state_dict`` copies the values into the module's parameters.
        self.latent_size = LATENT_SIZE
        self.vae = VAE(CHANNELS, LATENT_SIZE)
        vae_state = torch.load(vae_path, map_location='cpu')
        self.vae.load_state_dict(vae_state['state_dict'])
        for param in self.vae.parameters():
            param.requires_grad = False

        # MDN-RNN: same procedure, restored and frozen.
        self.n_gauss = N_GAUSS
        self.mdnrnn = MDNRNN(LATENT_SIZE, action_dim, HIDDEN_SIZE, N_GAUSS, rewards_terminal=False)
        mdnrnn_state = torch.load(mdnrnn_path, map_location='cpu')
        self.mdnrnn.load_state_dict(mdnrnn_state['state_dict'])
        for param in self.mdnrnn.parameters():
            param.requires_grad = False

        # Controller input = latent + logpi + mus + sigmas (see ``forward``).
        ctr_size = LATENT_SIZE + N_GAUSS + 2 * (LATENT_SIZE * N_GAUSS)
        self.fc1 = nn.Linear(ctr_size, 512)
        self.fc2 = nn.Linear(512, output_dim)

        self.activation = activation
        self.output_activation = output_activation
        self.output_squeeze = output_squeeze

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, query the MDN-RNN with ``a`` and run the controller.

        Args:
            x: input passed to the VAE.
            a: action tensor passed to the MDN-RNN (a leading dimension is
                added before the call).

        Returns:
            The controller output; every size-1 dimension is removed when
            ``output_squeeze`` is set.
        """
        # VAE: sample a latent vector from the returned mean / log-sigma.
        _, mu, logsigma = self.vae(x)
        latent = (mu + logsigma.exp() * torch.randn_like(mu)).view(-1, self.latent_size)

        # MDNRNN: both inputs get a leading dimension of size 1.
        mus, sigmas, logpi = self.mdnrnn(a.unsqueeze(0), latent.unsqueeze(0))

        # Drop that leading dimension again and flatten per sample.
        flat = self.n_gauss * self.latent_size
        mus = torch.squeeze(mus, dim=0).view(-1, flat)
        sigmas = torch.squeeze(sigmas, dim=0).view(-1, flat)
        logpi = torch.squeeze(logpi, dim=0).view(-1, self.n_gauss)

        # Controller
        x = torch.cat([latent, mus, sigmas, logpi], dim=1)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        if self.output_activation is not None:
            x = self.output_activation(x)
        # NOTE: ``squeeze()`` without a dim also drops a batch dimension of 1.
        return x.squeeze() if self.output_squeeze else x
コード例 #3
0
    def __init__(self, env_):
        """Wrap ``env_`` and restore the VAE weights from ``./vae/vae.pth``.

        Args:
            env_: the wrapped environment; its ``action_space`` is re-exposed.
        """
        self.env = env_
        self.dev = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.action_space = env_.action_space
        # self.observation_space = (1, 32)

        # vae
        self.vae = VAE()
        model_path = "./vae/vae.pth"
        # map_location lets weights saved on a GPU load on a CPU-only host.
        self.vae.load_state_dict(torch.load(model_path, map_location=self.dev))
        self.vae.to(self.dev)
        self._gen_id = 0  # how many times generate has been run
        self._frames = []  # frames kept for mp4 generation
コード例 #4
0
def main():
    """Train and evaluate a VAE on CIFAR-10.

    Settings come from ``get_config()``. Log records go both to a file under
    ``config.log_root`` and to the console. One training pass and one test
    pass are run per epoch.
    """
    config = get_config()
    log_file = os.path.join(config.log_root, config.log_name)
    logging.basicConfig(format='%(asctime)s | %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ],
                        level=logging.INFO)

    # ``transforms.Scale`` was deprecated in favour of ``transforms.Resize``
    # and is no longer available in current torchvision releases.
    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_set = datasets.CIFAR10(root=config.data_root,
                                 download=True,
                                 transform=transform,
                                 train=True)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=0,
                                               pin_memory=True)
    test_set = datasets.CIFAR10(root=config.data_root,
                                download=True,
                                transform=transform,
                                train=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=True)

    net = VAE(config.image_size)

    trainer = Trainer(model=net,
                      train_loader=train_loader,
                      test_loader=test_loader,
                      optimizer=Adam(net.parameters(),
                                     lr=0.0002,
                                     betas=(0.5, 0.999)),
                      loss_function=loss_function,
                      device='cpu')

    for epoch in range(config.epochs):
        trainer.train(epoch, config.log_metrics_every)
        trainer.test(epoch, config.log_metrics_every)
コード例 #5
0
def encode_main(args):
    """Encode the training and validation inputs with a VAE and save them.

    ``args.dataset`` is expected to look like ``<dataset>_<op>``; the two
    parts select the ``.npy`` files under ``data/<dataset>/``. The encoded
    arrays are written next to the inputs with an ``_enc`` suffix.
    """
    name_parts = args.dataset.split('_')
    dataset = name_parts[0]
    op = name_parts[1]
    model_path = args.model_path

    def npy_path(stem):
        # e.g. stem 'x_train' -> data/<dataset>/x_train_<dataset>_<op>.npy
        return 'data/{0}/{1}_{0}_{2}.npy'.format(dataset, stem, op)

    def enc_path(stem):
        return 'data/{0}/{1}_{0}_{2}_enc.npy'.format(dataset, stem, op)

    (X_train, _, X_val, _) = get_train_data_(npy_path('x_train'),
                                             npy_path('y_train'),
                                             npy_path('x_val'),
                                             npy_path('y_val'))

    vae_obj = VAE(batch_size=1, latent_dim=100, epsilon_std=1.)

    # Training split: encode, then persist the representation.
    X_train_enc = encode(X_train, vae_obj, model_path=model_path)
    np.save(enc_path('x_train'), X_train_enc)

    # Validation split: same treatment.
    X_val_enc = encode(X_val, vae_obj, model_path=model_path)
    np.save(enc_path('x_val'), X_val_enc)
コード例 #6
0
def show_reconstruction(cfg, args):
    """Decode the latent stored in ``z_tensor.pt`` and display the image.

    Args:
        cfg: settings object; ``IMAGE_CHANNELS``, ``VARIANTS_SIZE`` and
            ``TIME_STEPS`` are read.
        args: option mapping; ``args['--vae']`` is the VAE weights path.

    The process exits with status 1 when no VAE path is supplied.
    """
    vae_path = args['--vae']
    print('vae_path: %s' % vae_path)

    if not vae_path:
        print('Error: No vae path specified')
        # A missing argument is a failure: do not report success (status 0).
        sys.exit(1)

    # init vae
    print('Initializing vae...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vae = VAE(image_channels=cfg.IMAGE_CHANNELS, z_dim=cfg.VARIANTS_SIZE)
    vae.load_state_dict(torch.load(vae_path, map_location=device))
    vae.to(device).eval()

    for _ in range(cfg.TIME_STEPS):
        # The file is re-read on every iteration, as in the original code
        # (presumably another process rewrites it - TODO confirm).
        # map_location puts the latent on the same device as the VAE.
        z = torch.load('z_tensor.pt', map_location=device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            reconst = vae.decode(z)
        reconst = reconst.cpu()[0].numpy()
        # Scale to 0..255 and move the first axis last before PIL conversion.
        reconst = np.transpose(np.uint8(reconst * 255), [1, 2, 0])

        plt.imshow(F_.to_pil_image(reconst))
        plt.pause(0.05)

        time.sleep(0.1)
コード例 #7
0
ファイル: env.py プロジェクト: Ryunosuke-Ikeda/RL_Car0
 def __init__(self, env_):
     """Wrap ``env_`` and restore the VAE used to encode observations.

     Args:
         env_: the wrapped environment; its ``action_space`` is re-exposed.
     """
     self.env = env_
     self.dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     self.action_space = env_.action_space
     self._state_steps = 3
     self.observation_space = (80, 160, 3*self._state_steps)
     print("obs shape {}".format(self.observation_space))
     self.state_shape = 32*self._state_steps
     # vae
     self.vae = VAE()
     model_path = "./vae/vae_ikeda_ver1_32.pth"
     # map_location lets weights saved on a GPU load on a CPU-only host.
     self.vae.load_state_dict(torch.load(model_path, map_location=self.dev))
     self.vae.to(self.dev)
     # pre processing
     self._state_frames = deque(maxlen=self._state_steps)  # most recent per-step states
     self._gen_id = 0  # how many times generate has been run
     self._frames = []  # frames kept for mp4 generation
     self._step_repeat_times = 2
コード例 #8
0
def observe_and_learn(cfg, model_path, vae_path=None):
    """Train a new ``CustomSAC`` model and save it to ``model_path``.

    Args:
        cfg: settings object supplying the VAE sizes and the SAC
            hyper-parameters (``TIME_STEPS``, ``BATCH_SIZE``, ...).
        model_path: destination passed to ``model.save``.
        vae_path: optional VAE weights; when falsy the agent is created
            without a VAE and without a device.

    Any exception is reported on stdout together with a traceback and is not
    re-raised, so the caller keeps running.
    """
    global ctr

    try:

        time.sleep(5)

        vae = None
        device = None

        if vae_path:
            # init vae
            print('Initializing vae...')
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            vae = VAE(image_channels=cfg.IMAGE_CHANNELS,
                      z_dim=cfg.VARIANTS_SIZE)
            vae.load_state_dict(torch.load(vae_path, map_location=device))
            vae.to(device).eval()

        # create agent; wrapper for environment; later we can add vae to the agent
        agent = DonkeyAgent(cam,
                            time_step=0.05,
                            frame_skip=1,
                            env_type='simulator',
                            controller=ctr,
                            vae=vae,
                            device=device)
        print('DonkeyAgent created...')

        model = CustomSAC(CustomSACPolicy,
                          agent,
                          verbose=cfg.VERBOSE,
                          batch_size=cfg.BATCH_SIZE,
                          buffer_size=cfg.BUFFER_SIZE,
                          learning_starts=cfg.LEARNING_STARTS,
                          gradient_steps=cfg.GRADIENT_STEPS,
                          train_freq=cfg.TRAIN_FREQ,
                          ent_coef=cfg.ENT_COEF,
                          learning_rate=cfg.LEARNING_RATE)
        print('CustomSAC Initialized.')

        print('learning...')

        ctr.mode = 'local'

        # The step budget comes from the configuration; the former local
        # ``time_steps`` values were never read and have been removed.
        model.learn(total_timesteps=cfg.TIME_STEPS,
                    log_interval=cfg.LOG_INTERVAL)

        model.save(model_path)

        print('Model finished.')

    except Exception as e:
        print('error:no new model generated. %s' % e)
        traceback.print_exc()
コード例 #9
0
def main():
    """Build the VAE, restore it from the save file if present, then run.

    The save file path comes from ``parse_args().savefile``; the model sizes
    and the target device come from ``settings``.
    """
    args = parse_args()
    input_size = (settings.reduced_image_channels,
                  settings.reduced_image_width, settings.reduced_image_height)
    vae = VAE(input_size=input_size,
              latent_dim=settings.vae_latent_dim).to(settings.device)
    savefile = Path(args.savefile)
    if savefile.exists():
        # map_location keeps a checkpoint written on another device loadable
        # on the device this run is configured for.
        vae.load_state_dict(
            torch.load(str(savefile), map_location=settings.device))
        vae.eval()
    run(vae, savefile)
コード例 #10
0
ファイル: main.py プロジェクト: szrlee/vae-anomaly-detector
def train(config, trainloader, devloader=None):
    """Fit a VAE on ``trainloader`` with a timestamped checkpoint directory.

    Args:
        config: nested mapping; ``paths.checkpoints_directory``,
            ``model.name``, ``model.config_id`` and ``model.device`` are read.
        trainloader: loader whose ``dataset.input_dim_`` sizes the model.
        devloader: accepted but not used by this function.
    """
    stamp = datetime.now().strftime('%Y_%m_%d_%H_%M')
    checkpoints_root = config['paths']['checkpoints_directory']
    run_folder = '{name}{config_id}/'.format(
        name=config['model']['name'],
        config_id=config['model']['config_id'])
    checkpoint_directory = os.path.join(checkpoints_root, run_folder, stamp)
    os.makedirs(checkpoint_directory, exist_ok=True)

    model = VAE(trainloader.dataset.input_dim_, config, checkpoint_directory)
    model.to(config['model']['device'])
    model.fit(trainloader)
コード例 #11
0
ファイル: eval.py プロジェクト: szrlee/vae-anomaly-detector
def eval(config, testloader, num_rounds=100):
    """Evaluate the VAE repeatedly and pickle the averaged log-densities.

    Args:
        config: nested mapping; ``model.name``, ``model.config_id`` and
            ``model.device`` are read.
        testloader: loader whose ``dataset.input_dim_`` sizes the model and
            which is passed to ``vae.evaluate``.
        num_rounds: how many times ``vae.evaluate`` is run; the per-round
            log-densities are combined as ``logsumexp - log(num_rounds)``.
            Defaults to 100, the value that used to be hard-coded.

    Raises:
        ValueError: if ``num_rounds`` is smaller than 1.

    Note:
        The checkpoint name is read from the module-level ``args``
        (``args.restore_filename``), not from a parameter.
    """
    if num_rounds < 1:
        raise ValueError('num_rounds must be at least 1')

    storage = {
        'log_densities': None,
        'params': None,
        'ground_truth': None
    }
    input_dim = testloader.dataset.input_dim_
    vae = VAE(input_dim, config, checkpoint_directory=None)
    vae.to(config['model']['device'])
    if args.restore_filename is not None:
        vae.restore_model(args.restore_filename, epoch=None)
    vae.eval()

    precisions, recalls, all_log_densities = [], [], []
    for i in range(num_rounds):
        print("evaluation round {}".format(i))
        _, _, precision, recall, log_densities, ground_truth = vae.evaluate(
            testloader)
        precisions.append(precision)
        recalls.append(recall)
        # One column per round.
        all_log_densities.append(np.expand_dims(log_densities, axis=1))
    print(mean_confidence_interval(precisions))
    print(mean_confidence_interval(recalls))

    all_log_densities = np.concatenate(all_log_densities, axis=1)
    # Log of the mean of the per-round densities: log-sum-exp minus log(n).
    # Using the same variable for the loop and the normaliser keeps the two
    # from drifting apart.
    storage['log_densities'] = logsumexp(all_log_densities,
                                         axis=1) - np.log(num_rounds)
    # Ground truth as returned by the last round.
    storage['ground_truth'] = ground_truth

    pkl_filename = './results/test/{}{}/{}.pkl'.format(
        config['model']['name'], config['model']['config_id'],
        args.restore_filename)
    os.makedirs(os.path.dirname(pkl_filename), exist_ok=True)
    with open(pkl_filename, 'wb') as _f:
        pickle.dump(storage, _f, pickle.HIGHEST_PROTOCOL)
コード例 #12
0
def train_VAE(args):
    """Train the VAE on the dataset selected by ``args.dataset``.

    ``args.dataset`` is expected to look like ``<dataset>_<op>``; the two
    parts select the ``.npy`` files under ``data/<dataset>/`` and the
    checkpoint output directory ``checkpoint/<dataset>_<op>``.
    """
    name_parts = args.dataset.split('_')
    dataset = name_parts[0]
    op = name_parts[1]

    def npy_path(stem):
        # e.g. stem 'x_train' -> data/<dataset>/x_train_<dataset>_<op>.npy
        return 'data/{0}/{1}_{0}_{2}.npy'.format(dataset, stem, op)

    # Load training and validation data.
    splits = get_train_data_(npy_path('x_train'),
                             npy_path('y_train'),
                             npy_path('x_val'),
                             npy_path('y_val'))
    (X_train, y_train, X_val, y_val) = splits

    vae_obj = VAE(batch_size=1, latent_dim=100, epsilon_std=1.)

    # The returned pair is unpacked exactly as before; neither value is used.
    vae, enc = train(data=(X_train, y_train, X_val, y_val),
                     vae_obj=vae_obj,
                     model_path=args.model_path,
                     batch_size=args.batch_size,
                     epochs=args.epochs,
                     out_dir='checkpoint/{}_{}'.format(dataset, op))
コード例 #13
0
def optimize_model(cfg,
                   model_path,
                   vae_path=None,
                   auto_mode=0,
                   env_type='simulator'):
    """Load a saved ``CustomSAC`` model, train it further and save it back.

    Args:
        cfg: settings object (VAE sizes and ``LOG_INTERVAL`` are read).
        model_path: file the model is loaded from and saved to.
        vae_path: optional VAE weights; when falsy no VAE is used.
        auto_mode: forwarded to ``DonkeyAgent``.
        env_type: forwarded to ``DonkeyAgent``.
    """
    global ctr

    total_steps = 5000
    encoder = None

    if vae_path:
        # Restore the VAE on whichever device is available.
        print('Initializing vae...')
        target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        encoder = VAE(image_channels=cfg.IMAGE_CHANNELS, z_dim=cfg.VARIANTS_SIZE)
        weights = torch.load(vae_path, map_location=torch.device(target))
        encoder.load_state_dict(weights)
        encoder.to(target).eval()

    # The agent wraps the environment (camera, controller and optional VAE).
    donkey = DonkeyAgent(cam,
                         time_step=0.05,
                         frame_skip=1,
                         env_type=env_type,
                         controller=ctr,
                         vae=encoder,
                         auto_mode=auto_mode)
    print('DonkeyAgent created...')

    sac = CustomSAC.load(model_path, donkey)
    print('Executing model...')

    # Learn from the very first step of this session.
    sac.learning_starts = 0
    ctr.mode = 'local'

    sac.learn(total_timesteps=total_steps, log_interval=cfg.LOG_INTERVAL)
    sac.save(model_path)

    print('Model finished.')
コード例 #14
0
def drive_model(cfg, model_path, vae_path=None, env_type='simulator'):
    """Drive with a saved ``CustomSAC`` policy for a fixed number of steps.

    Args:
        cfg: settings object (``IMAGE_CHANNELS`` and ``VARIANTS_SIZE`` are
            read when a VAE is used).
        model_path: file the model is loaded from.
        vae_path: optional VAE weights; when falsy no VAE is used.
        env_type: forwarded to ``DonkeyAgent``. It used to be ignored in
            favour of a hard-coded ``'simulator'``; the default keeps that
            behaviour for existing callers.
    """
    global ctr

    time_steps = 10000

    vae = None

    if vae_path:
        # init vae
        print('Initializing vae...')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        vae = VAE(image_channels=cfg.IMAGE_CHANNELS, z_dim=cfg.VARIANTS_SIZE)
        vae.load_state_dict(torch.load(vae_path, map_location=device))
        vae.to(device).eval()

    # create agent; wrapper for environment; later we can add vae to the agent
    agent = DonkeyAgent(cam,
                        time_step=0.05,
                        frame_skip=1,
                        env_type=env_type,
                        controller=ctr,
                        vae=vae)
    print('DonkeyAgent created...')

    model = CustomSAC.load(model_path)
    print('Executing model...')
    obs = agent.reset()

    ctr.mode = 'local'

    for step in range(time_steps):
        if step % 100 == 0:
            print("step: ", step)
        action, _ = model.predict(obs, deterministic=False)
        obs, _, _, _ = agent.step(action)
        time.sleep(0.05)

        # Block here until the agent no longer reports game over.
        while agent.is_game_over():
            print('waiting for control')
            time.sleep(1)
コード例 #15
0
def make_session_name(model_name, agent_name,  iteration, agent):
    """Return the session label for one model/agent/iteration combination.

    The label has the form
    ``<model>_<agent>_iteration_<n>_h<horizon>_g<max_generations>_sb<is_shift_buffer>``
    where the last three values are read from ``agent``.
    """
    template = '{}_{}_iteration_{}_h{}_g{}_sb{}'
    return template.format(model_name,
                           agent_name,
                           iteration,
                           agent.horizon,
                           agent.max_generations,
                           agent.is_shift_buffer)


"""
- Define experiment names
- Define what iteration to start from ie 0 is beginning
- Ensure specific test is available in planning tester
- Ensure is_logging is true in config under test_suite to log results
- Run the script
"""
# Script entry point: reload the VAE on CPU, build the planning agent and
# collect the MDRNN checkpoint files of every experiment to retest.
# NOTE(review): the loop body below appears truncated in this excerpt -
# ``iteration_results`` is created but nothing visible uses it yet.
if __name__ == '__main__':
    # Limit both torch and OpenMP to a single thread.
    torch.set_num_threads(1)
    os.environ['OMP_NUM_THREADS'] = str(1)
    frame_preprocessor = Preprocessor(config['preprocessor'])
    vae = VAE(config)
    vae_trainer = VaeTrainer(config, frame_preprocessor)
    # Replace the fresh model with the one returned by the trainer's reload.
    vae = vae_trainer.reload_model(vae, device='cpu')
    agent = get_planning_agent()

    experiment_names = ['World_Model_Iter_A']  # TODO DEFINE MODELS TO RETEST (add many for multiple models)
    start_from_iter = 0                        # TODO DEFINE WHAT ITERATION TO START FROM

    for experiment_name in experiment_names:
        mdrnn_models_location = 'mdrnn/checkpoints/backups'
        # Every file below the backup directory, as a full path ...
        files = [os.path.join(root, name) for root, dirs, files in os.walk(mdrnn_models_location) for name in files]
        # ... restricted to paths that contain the experiment name ...
        files = [file for file in files if experiment_name in file]
        # ... ordered by the key that get_digit_from_path derives from each path.
        files.sort(key=lambda path: get_digit_from_path(path))
        print(files)

        iteration_results = {}
コード例 #16
0
class MyEnv:
    """Environment wrapper that returns VAE-encoded observations.

    ``step`` and ``reset`` delegate to the wrapped environment, pass every
    observation through ``self.vae.encode`` and reshape the reward using the
    ``"cte"`` and ``"speed"`` entries of the ``info`` dict.
    """

    def __init__(self, env_):
        """Wrap ``env_`` and restore the VAE weights from ``./vae/vae.pth``."""
        self.env = env_
        self.dev = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.action_space = env_.action_space
        # self.observation_space = (1, 32)

        # vae
        self.vae = VAE()
        model_path = "./vae/vae.pth"
        # map_location lets weights saved on a GPU load on a CPU-only host.
        self.vae.load_state_dict(torch.load(model_path, map_location=self.dev))
        self.vae.to(self.dev)
        self._gen_id = 0  # how many times generate_mp4 has completed
        self._frames = []  # raw frames kept for mp4 generation

    def step(self, action, show=False):
        """Advance the wrapped environment by one step.

        Args:
            action: forwarded to ``self.env.step``.
            show: when true, the raw observation is stored for
                ``generate_mp4``.

        Returns:
            ``(encoded_state, reward, done, info)``. The episode is ended
            with reward -1.0 when ``info["cte"]`` is above 3.5 or below -5.0.
        """
        n_state, rew, done, info = self.env.step(action)
        if show:
            self._frames.append(np.array(n_state))
        n_state = self.convert_state_vae(n_state)
        rew = self.change_rew(rew, info)
        if info["cte"] > 3.5 or info["cte"] < -5.0:
            done = True
            rew = -1.0
        return n_state, rew, done, info

    def change_rew(self, rew, info):
        """Reshape the environment reward.

        Returns -0.6 for negative speed and -1.0 when ``|cte| >= 2.0``.
        Otherwise a positive reward is divided by 20 and, above speed 3.0,
        increased by ``speed / 30``; non-positive rewards pass through.
        """
        if info["speed"] < 0.0:
            return -0.6
        elif abs(info["cte"]) >= 2.0:
            return -1.0
        if rew > 0.0:
            rew /= 20.0
            if info["speed"] > 3.0:
                rew += info["speed"] / 30.0
        return rew

    def reset(self):
        """Reset the wrapped environment and return the encoded state."""
        state = self.env.reset()
        state = self.convert_state_vae(state)
        return state

    def seed(self, seed_):
        """Forward the seed to the wrapped environment."""
        self.env.seed(seed_)

    def convert_state_to_tensor(
            self, state):  # state(array) -> np.array -> convert some -> tensor
        """Turn a raw observation into a float tensor scaled to ``[0, 1]``.

        The observation is reshaped to ``(160, 120, 3)``, the slice
        ``[:, 40:120, :]`` is kept and reshaped to ``(1, 80, 160, 3)``, and
        the result is permuted to ``(1, 3, 80, 160)`` on ``self.dev``.
        """
        state_ = np.array(state).reshape((160, 120, 3))
        state_ = state_[0:160, 40:120, :].reshape((1, 80, 160, 3))
        state_ = torch.from_numpy(state_).permute(0, 3, 1,
                                                  2).float().to(self.dev)
        state_ /= 255.0
        return state_

    def convert_state_vae(self, state):
        """Encode a raw observation with the VAE.

        Returns the first of the three values produced by
        ``self.vae.encode``, for the single batch element, as a NumPy array.
        """
        state_ = self.convert_state_to_tensor(state)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            state_, _, _ = self.vae.encode(state_)
        state_ = state_.clone().detach().cpu().numpy()[0]
        return state_

    def generate_mp4(self):
        """Write the stored frames to ``<cwd>/mp4/output<id>.mp4``.

        Each frame is saved as a PNG and read back with OpenCV before being
        resized and appended to the video. On success the frame buffer is
        cleared and the output id is incremented.
        """
        # Local import: ``tempfile`` is only needed by this method.
        import tempfile

        # for mp4
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        fps = 10
        current_path = os.getcwd()  # current working directory
        video = cv2.VideoWriter(
            current_path + '/mp4/output' + str(self._gen_id) + ".mp4", fourcc,
            fps, (120, 160))
        if not video.isOpened():
            print("can't be opened")
        try:
            # A private temporary directory cannot collide with an existing
            # ``./tmp`` and is removed even if writing a frame fails.
            with tempfile.TemporaryDirectory() as tmp_dir:
                for idx, frame in enumerate(self._frames):
                    path = os.path.join(tmp_dir, "frame" + str(idx) + ".png")
                    Image.fromarray(frame, "RGB").save(path, 'PNG')
                    img = cv2.imread(path)
                    # Check before resizing: ``cv2.resize`` cannot take None,
                    # so the former post-resize check could never trigger.
                    if img is None:
                        print("can't read")
                        break
                    video.write(cv2.resize(img, (120, 160)))
        finally:
            # Always finalise the video file, even on error.
            video.release()
        self._frames.clear()
        self._gen_id += 1
コード例 #17
0
ファイル: env.py プロジェクト: Ryunosuke-Ikeda/RL_Car0
class MyEnv:
    """Environment wrapper that returns stacked, VAE-encoded observations.

    Every action is repeated ``_step_repeat_times`` times. The first
    observation of each repetition group is VAE-encoded and pushed into a
    bounded deque of ``_state_steps`` entries, whose concatenation is the
    state handed to the agent.
    """

    def __init__(self, env_):
        """Wrap ``env_`` and restore the VAE used to encode observations."""
        self.env = env_
        self.dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.action_space = env_.action_space
        self._state_steps = 3
        self.observation_space = (80, 160, 3*self._state_steps)
        print("obs shape {}".format(self.observation_space))
        self.state_shape = 32*self._state_steps
        # vae
        self.vae = VAE()
        model_path = "./vae/vae_ikeda_ver1_32.pth"
        # map_location lets weights saved on a GPU load on a CPU-only host.
        self.vae.load_state_dict(torch.load(model_path, map_location=self.dev))
        self.vae.to(self.dev)
        # pre processing
        self._state_frames = deque(maxlen=self._state_steps)  # most recent encoded frames
        self._gen_id = 0  # how many times generate_mp4 has completed
        self._frames = []  # raw frames kept for mp4 generation
        self._step_repeat_times = 2
        # self.num_envs = 1

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def step(self, action, show=False):
        """Repeat ``action`` and return the stacked state.

        Args:
            action: forwarded to ``self.env.step`` on every repetition.
            show: when true, every raw observation is stored for
                ``generate_mp4``.

        Returns:
            ``(state, reward, done, info)``. ``info`` is the one from the
            last executed repetition. The episode is ended with reward -1.0
            when ``info["cte"]`` is above 2.5 or below -5.0.
        """
        rews = 0.0
        for i in range(self._step_repeat_times):
            n_state, rew, done, info = self.env.step(action)
            rews += rew
            if i == 0:
                self._state_frames.append(self.adjust_picture(n_state))
            if show:
                self._frames.append(np.array(n_state))
            if done:
                break
        n_state_return = self.convert_state()  # build the stacked state
        # The sum is divided by the configured repeat count even when the
        # loop stopped early.
        rew = self.change_rew(rews/self._step_repeat_times, info)
        if info["cte"] > 2.5 or info["cte"] < -5.0:
            done = True
            rew = -1.0
        return n_state_return, rew, done, info

    def change_rew(self, rew, info):
        """Compute the reward from ``info``; the incoming ``rew`` is ignored.

        Returns -0.6 for negative speed, ``speed / 100`` when ``|cte| < 1.0``
        and -0.6 otherwise.
        """
        if info["speed"] < 0.0:
            return -0.6
        if abs(info["cte"]) < 1.0:
            rew = info["speed"] / 100.0
        else:
            rew = -0.6
        return rew

    def reset(self):
        """Reset the environment and refill the frame buffer.

        After the reset, between ``_state_steps`` and ``_state_steps + 9``
        randomly sampled actions are executed so that the deque holds fresh
        frames.
        """
        rand_step = random.randrange(10)
        self.env.reset()
        for _ in range(rand_step + self._state_steps):
            action = self.env.action_space.sample()
            for i in range(self._step_repeat_times):
                n_state, _, _, _ = self.env.step(action)
                if i == 0:
                    self._state_frames.append(self.adjust_picture(n_state))
        state = self.convert_state()
        return state

    def seed(self, seed_):
        """Forward the seed to the wrapped environment."""
        self.env.seed(seed_)

    def adjust_picture(self, pict):
        """Return the VAE encoding of a raw observation."""
        vae_state = self.convert_state_vae(pict)
        return vae_state

    def convert_state(self):
        """Concatenate the buffered encodings along axis 0."""
        return np.concatenate(list(self._state_frames), 0)

    def convert_state_to_tensor(self, state):  # state(array) -> np.array -> convert some -> tensor
        """Turn a raw observation into a float tensor on ``self.dev``.

        The observation is reshaped to ``(160, 120, 3)``, the slice
        ``[:, 40:120, :]`` is kept and reshaped to ``(1, 80, 160, 3)``, and
        the result is permuted to ``(1, 3, 80, 160)``. No value scaling is
        applied here.
        """
        state_ = np.array(state).reshape((160, 120, 3))
        state_ = state_[0:160, 40:120, :].reshape((1, 80, 160, 3))
        state_ = torch.from_numpy(state_).permute(0, 3, 1, 2).float().to(self.dev)
        return state_

    def convert_state_vae(self, state):
        """Encode a raw observation with the VAE.

        Returns the first of the three values produced by
        ``self.vae.encode``, for the single batch element, as a NumPy array.
        """
        state_ = self.convert_state_to_tensor(state)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            state_, _, _ = self.vae.encode(state_)
        state_ = state_.clone().detach().cpu().numpy()[0]
        return state_

    def generate_mp4(self):
        """Write the stored frames to ``./mp4/output<id>.mp4``.

        Each frame is saved as a PNG and read back with OpenCV before being
        resized and appended to the video. On success the frame buffer is
        cleared and the output id is incremented.
        """
        # Local import: ``tempfile`` is only needed by this method.
        import tempfile

        # for mp4
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        fps = 10
        video = cv2.VideoWriter('./mp4/output' + str(self._gen_id) + ".mp4", fourcc, fps, (120, 160))
        if not video.isOpened():
            print("can't be opened")
        try:
            # A private temporary directory cannot collide with an existing
            # ``./tmp`` and is removed even if writing a frame fails.
            with tempfile.TemporaryDirectory() as tmp_dir:
                for idx, frame in enumerate(self._frames):
                    path = os.path.join(tmp_dir, "frame" + str(idx) + ".png")
                    Image.fromarray(frame, "RGB").save(path, 'PNG')
                    img = cv2.imread(path)
                    # Check before resizing: ``cv2.resize`` cannot take None,
                    # so the former post-resize check could never trigger.
                    if img is None:
                        print("can't read")
                        break
                    video.write(cv2.resize(img, (120, 160)))
        finally:
            # Always finalise the video file, even on error.
            video.release()
        self._frames.clear()
        self._gen_id += 1
コード例 #18
0
ファイル: main.py プロジェクト: zivzone/WorldModelPlanning
 def train_or_reload_vae(self):
     """Return a trained VAE or a reloaded one, depending on the config.

     When ``self.config["is_train_vae"]`` is truthy the VAE is passed to
     ``VaeTrainer.train``, otherwise to ``VaeTrainer.reload_model``.
     """
     vae_trainer = VaeTrainer(self.config, self.frame_preprocessor)
     vae = VAE(self.config)
     # Read the flag from the instance's own config, like the trainer and
     # the model above, rather than from a module-level ``config``.
     if self.config["is_train_vae"]:
         return vae_trainer.train(vae)
     return vae_trainer.reload_model(vae)
コード例 #19
0
def _load_vae(model_path, variants_size, image_channels, device):
    """Create a VAE, restore its weights and return it in eval mode.

    Args:
        model_path: file holding the saved state dict.
        variants_size: passed to the VAE as ``z_dim``.
        image_channels: passed to the VAE as ``image_channels``.
        device: anything ``torch.device`` accepts; used both as the load
            location and as the device the model is moved to.
    """
    model = VAE(image_channels=image_channels, z_dim=variants_size)
    weights = torch.load(model_path, map_location=torch.device(device))
    model.load_state_dict(weights)
    model = model.to(torch.device(device))
    model.eval()
    return model