示例#1
0
    def train_gru(self, gru_net, gru_net_path, gru_plot_dir, train_data,
                  batch_size, train_epochs, cuda, bn_episodes,
                  bottleneck_data_path, generate_max_steps, gru_prob_data_path,
                  gru_dir, *, lr=1e-3, trunc_k=50, n_trajectories=500):
        """Train a GRU network, then generate the data sets derived from it.

        Steps, in order:
          1. Train ``gru_net`` via ``gru_nn.train`` with an Adam optimizer.
          2. Generate bottleneck-training data via
             ``tl.generate_bottleneck_data``.
          3. Generate trajectories via ``tl.generate_trajectories``.
          4. Write a README for the network via ``tl.write_net_readme``,
             recording the elapsed time under ``'time_taken'``.

        Args:
            gru_net: The GRU network (a torch module: it exposes ``train()``,
                ``eval()``, ``parameters()`` and ``cpu()``).
            gru_net_path: Path forwarded to ``gru_nn.train`` (presumably
                where the trained weights are saved -- confirm in gru_nn).
            gru_plot_dir: Plot directory forwarded to ``gru_nn.train``.
            train_data: Training data forwarded to ``gru_nn.train``.
            batch_size: Forwarded to both ``gru_nn.train`` and
                ``tl.generate_trajectories``.
            train_epochs: Number of training epochs for ``gru_nn.train``.
            cuda: CUDA flag forwarded to training and to bottleneck data
                generation (not to trajectory generation -- see Returns).
            bn_episodes: Episode count for bottleneck data generation.
            bottleneck_data_path: Output path for bottleneck data.
            generate_max_steps: ``max_steps`` for bottleneck data generation.
            gru_prob_data_path: Output path for generated trajectories.
            gru_dir: Directory handed to ``tl.write_net_readme``.
            lr: Adam learning rate (default: the previously hard-coded 1e-3).
            trunc_k: Forwarded to ``gru_nn.train`` as ``trunc_k`` (default
                50; presumably a truncated-BPTT length -- confirm in gru_nn).
            n_trajectories: Second positional argument of
                ``tl.generate_trajectories`` (default 500; presumably the
                number of trajectories/episodes -- confirm in tl).

        Returns:
            The network returned by ``gru_nn.train``, left in eval mode. NOTE:
            ``.cpu()`` on a torch module moves it in place, so the returned
            network is on the CPU even when ``cuda`` is True.
        """
        logging.info('Training GRU!')
        # perf_counter is monotonic, unlike time.time(), so the recorded
        # duration is not skewed by system clock adjustments.
        start_time = time.perf_counter()
        gru_net.train()
        optimizer = optim.Adam(gru_net.parameters(), lr=lr)
        gru_net = gru_nn.train(gru_net,
                               self.env,
                               optimizer,
                               gru_net_path,
                               gru_plot_dir,
                               train_data,
                               batch_size,
                               train_epochs,
                               cuda,
                               trunc_k=trunc_k)
        logging.info('Generating Data-Set for Later Bottle Neck Training')
        gru_net.eval()
        tl.generate_bottleneck_data(gru_net,
                                    self.env,
                                    bn_episodes,
                                    bottleneck_data_path,
                                    cuda=cuda,
                                    max_steps=generate_max_steps)
        # Trajectory generation runs on the CPU; this moves gru_net in place.
        tl.generate_trajectories(self.env, n_trajectories, batch_size,
                                 gru_prob_data_path, gru_net.cpu())
        tl.write_net_readme(gru_net,
                            gru_dir,
                            info={'time_taken': time.perf_counter() - start_time})

        return gru_net
示例#2
0
        # ***********************************************************************************
        # Generating BottleNeck training data                                               *
        # ***********************************************************************************
        if args.generate_bn_data:
            tl.set_log(data_dir, 'generate_bn_data')
            logging.info('Generating Data-Set for Later Bottle Neck Training')
            # Rebuild the GRU network and restore its trained weights.
            gru_net = GRUNet(len(obs), args.gru_size, int(env.action_space.n))
            # Deserialize onto the CPU first. Without map_location, a checkpoint
            # saved from a CUDA run fails to load on a CPU-only machine.
            # load_state_dict copies the tensors into gru_net's parameters, and
            # the .cuda() call below moves them to the GPU when requested.
            gru_net.load_state_dict(torch.load(gru_net_path, map_location='cpu'))
            # NOTE(review): presumably disables noise injection used during
            # training -- confirm in GRUNet.
            gru_net.noise = False
            if args.cuda:
                gru_net = gru_net.cuda()
            gru_net.eval()
            # NOTE(review): eps=(0, 0.3) is presumably an exploration-epsilon
            # range -- confirm in tl.generate_bottleneck_data.
            tl.generate_bottleneck_data(gru_net,
                                        env,
                                        args.bn_episodes,
                                        bottleneck_data_path,
                                        cuda=args.cuda,
                                        eps=(0, 0.3),
                                        max_steps=args.generate_max_steps)
            # Small rendered trajectory run (positional args 3 and 5 are
            # interpreted by tl.generate_trajectories).
            tl.generate_trajectories(env,
                                     3,
                                     5,
                                     gru_prob_data_path,
                                     gru_net,
                                     cuda=args.cuda,
                                     render=True)

        # ***********************************************************************************
        # HX-QBN                                                                            *
        # ***********************************************************************************
        if args.bhx_train or args.bhx_test: