    def test_shape_of_output_of_6d_rnn(self):
        units = 7
        last_dim_size = 12
        rnn = MDRNN(units=units, input_shape=(None, None, None, None, None, None, last_dim_size),
                    return_sequences=True,
                    activation='tanh')

        x = tf.zeros(shape=(2, 3, 1, 2, 2, 1, 5, last_dim_size))

        result = rnn.call(x)
        self.assertEqual((2, 3, 1, 2, 2, 1, 5, units), result.shape)
Example 2
import math

import torch
import tqdm


def ACO(aco_iterations, n_inputs, n_outputs, n_hiddens, pheromones):
    """
    Run the ACO algorithm for aco_iterations iterations.

    n_hiddens should equal n_inputs.
    """
    # Hyper parameters
    n_gaussians = 1  # is this dependent on other factors?
    n_models = 5  # models sampled per iteration (10 in the later version)
    deg_freq = 5  # apply pheromone decay every deg_freq iterations
    batch_size = 16

    # Initialize population
    population = []
    cur_best = math.inf
    training_data, test_data = load_data()
    for iteration in tqdm.tqdm(range(aco_iterations)):
        # Generate paths for the models

        for _ in range(n_models):
            model = MDRNN(n_inputs, n_outputs, n_hiddens, n_gaussians)
            paths = Paths(n_inputs, n_hiddens, pheromones)
            # Prune the model's connections according to the sampled paths
            prune_layer(model, paths, n_inputs, n_hiddens)

            # Training loop
            train(model, training_data, batch_size)

            # Update fitness
            model.fitness = test(model, test_data, batch_size)

            # Add model to population
            population.append((model, paths))

        # Update pheromones
        for model, paths in population:
            if model.fitness < cur_best:  # Reward
                print("Great")
                torch.save(model.state_dict(), 'model_weights.pth')
                cur_best = model.fitness
                pheromones.update(paths, 0)
            else:  # Punishment
                pheromones.update(paths, 1)
            if iteration % deg_freq == 0:  # Decay step
                pheromones.update(paths, 2)
    # Collect and report the best-performing models
    bests = []
    for model, paths in population:
        if model.fitness == cur_best:
            bests.append((model, paths))
            print(model.fitness, model)
    return bests
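The loop above drives exploration through a pheromones object with an update(paths, code) protocol: code 0 rewards the sampled paths, 1 punishes them, and 2 applies decay. Neither Pheromones nor Paths is defined in this listing, so the sketch below only illustrates the assumed interface; every name, attribute, and constant in it is hypothetical.

import numpy as np


class Pheromones:
    """Hypothetical pheromone matrix over input-to-hidden connections."""

    def __init__(self, n_inputs, n_hiddens, rho=0.1):
        self.levels = np.ones((n_inputs, n_hiddens))
        self.rho = rho  # reinforcement / evaporation rate

    def update(self, paths, code):
        if code == 2:  # decay step: global evaporation
            self.levels *= (1.0 - self.rho)
            return
        for i, j in paths.edges:  # paths.edges assumed: sampled (input, hidden) pairs
            if code == 0:  # reward: reinforce the sampled edges
                self.levels[i, j] += self.rho
            elif code == 1:  # punishment: weaken the sampled edges
                self.levels[i, j] = max(self.levels[i, j] - self.rho, 1e-3)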
    def test_shape(self):
        rnn2d = MDRNN(units=5, input_shape=(None, None, 1),
                      kernel_initializer=initializers.Constant(1),
                      recurrent_initializer=initializers.Constant(1),
                      bias_initializer=initializers.Constant(-1),
                      return_sequences=True,
                      activation='tanh')

        x = np.arange(6).reshape((1, 2, 3, 1))

        res = rnn2d.call(x)

        self.assertEqual((1, 2, 3, 5), res.shape)
    def test_for_two_step_sequence(self):
        kernel_initializer = initializers.Zeros()
        recurrent_initializer = kernel_initializer

        mdrnn = MDRNN(units=2,
                      input_shape=(None, 1),
                      kernel_initializer=kernel_initializer,
                      recurrent_initializer=recurrent_initializer,
                      bias_initializer=initializers.Constant(5),
                      return_sequences=True)
        x = np.zeros((1, 3, 1))
        a = mdrnn.call(x)
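        # with all-zero kernel and recurrent weights, every output is tanh(bias) = tanh(5) ≈ 0.9999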

        expected_result = np.ones((1, 3, 2)) * 0.9999
        np.testing.assert_almost_equal(expected_result, a.numpy(), 4)
    def test_feed_uni_directional(self):
        rnn = MDRNN(units=16,
                    input_shape=(5, 4, 10),
                    activation='tanh',
                    return_sequences=True)
        output = rnn(np.zeros((1, 5, 4, 10)))
        self.assertEqual((1, 5, 4, 16), output.shape)
    def test_output_sequences_match(self):
        self.kwargs.update(dict(return_sequences=True))
        rnn = MDRNN(**self.kwargs)
        keras_rnn = tf.keras.layers.SimpleRNN(**self.kwargs)

        x = tf.constant(np.random.rand(3, 4, 5), dtype=tf.float32)

        np.testing.assert_almost_equal(rnn(x).numpy(), keras_rnn(x).numpy(), 6)
    def make_rnn(self, return_sequences, return_state):
        shape = self.x.shape[1:]
        rnn = MDRNN(units=self.units,
                    input_shape=shape,
                    return_sequences=return_sequences,
                    return_state=return_state,
                    activation='tanh')

        return MultiDirectional(rnn)
    def test_2drnn_output_when_providing_initial_state(self):
        rnn2d = MDRNN(units=1, input_shape=(None, None, 1),
                      kernel_initializer=initializers.Identity(),
                      recurrent_initializer=initializers.Identity(1),
                      bias_initializer=initializers.Constant(-1),
                      return_sequences=True,
                      activation=None)

        x = np.arange(6).reshape((1, 2, 3, 1))

        initial_state = [tf.ones(shape=(1, 1)), tf.ones(shape=(1, 1))]

        actual = rnn2d.call(x, initial_state=initial_state)
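        # With identity kernel/recurrent weights, no activation, and bias -1, each
        # cell computes h(i, j) = x(i, j) + h(i-1, j) + h(i, j-1) - 1, where
        # out-of-grid neighbours contribute 0 and the provided initial state feeds
        # only the first cell: h(0, 0) = 0 + 1 + 1 - 1 = 1, h(1, 1) = 4 + 1 + 3 - 1 = 7.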
        desired = np.array([
            [1, 1, 2],
            [3, 7, 13]
        ]).reshape((1, 2, 3, 1))
        np.testing.assert_almost_equal(desired, actual.numpy(), 6)
    def test_result(self):
        rnn2d = MDRNN(units=1, input_shape=(None, None, 1),
                      kernel_initializer=initializers.Identity(),
                      recurrent_initializer=initializers.Identity(1),
                      bias_initializer=initializers.Constant(-1),
                      return_sequences=True,
                      activation=None)

        x = np.arange(6).reshape((1, 2, 3, 1))

        actual = rnn2d.call(x)
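        # Same recurrence as above but with zero initial state:
        # h(0, 0) = 0 - 1 = -1, h(1, 1) = 4 + (-1) + 1 - 1 = 3, h(1, 2) = 5 + 0 + 3 - 1 = 7.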

        desired = np.array([
            [-1, -1, 0],
            [1, 3, 7]
        ]).reshape((1, 2, 3, 1))

        np.testing.assert_almost_equal(desired, actual.numpy(), 6)
Example 10
    def create_model(self, input_dim, units, seq_len, return_sequences,
                     output_units=1, output_activation=None):
        model = tf.keras.Sequential()
        model.add(MDRNN(units=units,
                        input_shape=[seq_len, input_dim],
                        return_sequences=return_sequences))

        model.add(tf.keras.layers.Dense(units=output_units, activation=output_activation))
        model.compile(optimizer='sgd', loss='mean_squared_error', metrics=[])
        return model
    def make_rnn(self):
        kwargs = dict(units=1,
                      input_shape=(None, None, 1),
                      kernel_initializer=initializers.Identity(),
                      recurrent_initializer=initializers.Identity(1),
                      bias_initializer=initializers.Constant(-1),
                      return_sequences=True,
                      return_state=True,
                      direction=self._direction,
                      activation=None)

        return MDRNN(**kwargs)
    def create_mdrnn(self, direction):
        kernel_initializer = initializers.zeros()
        recurrent_initializer = initializers.constant(2)
        bias_initializer = initializers.zeros()

        return MDRNN(units=1, input_shape=(None, 1),
                     kernel_initializer=kernel_initializer,
                     recurrent_initializer=recurrent_initializer,
                     bias_initializer=bias_initializer,
                     activation=None,
                     return_sequences=True,
                     direction=direction)
    def test_fit_uni_directional(self):
        model = tf.keras.Sequential()
        model.add(MDRNN(units=16, input_shape=(2, 3, 6), activation='tanh'))
        model.add(tf.keras.layers.Dense(units=10, activation='softmax'))
        model.compile(loss='categorical_crossentropy', metrics=['acc'])
        model.summary()
        x = np.zeros((10, 2, 3, 6))
        y = np.zeros((10, 10))
        model.fit(x, y)
Example 14
    def test_result_after_running_rnn_on_3d_input(self):
        rnn3d = MDRNN(units=1,
                      input_shape=(None, None, None, 1),
                      kernel_initializer=initializers.Identity(),
                      recurrent_initializer=initializers.Identity(1),
                      bias_initializer=initializers.Constant(1),
                      return_sequences=True,
                      return_state=True,
                      activation=None)

        x = np.arange(2 * 2 * 2).reshape((1, 2, 2, 2, 1))

        outputs, state = rnn3d.call(x)
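        # With identity weights, bias +1, and no activation, each cell sums its
        # input and its three already-computed neighbours plus 1, e.g.
        # h(1, 1, 1) = 7 + 11 + 15 + 17 + 1 = 51; the returned state is the
        # activation of the last cell along every dimension.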

        desired = np.array([[[1, 3], [4, 11]], [[6, 15], [17, 51]]]).reshape(
            (1, 2, 2, 2, 1))

        np.testing.assert_almost_equal(desired, outputs.numpy(), 6)

        desired_state = desired[:, -1, -1, -1]
        np.testing.assert_almost_equal(desired_state, state.numpy(), 6)
    def test_1d_rnn_produces_correct_output_for_2_steps(self):
        kernel_initializer = initializers.identity()
        recurrent_initializer = kernel_initializer

        bias = 3
        bias_initializer = initializers.Constant(bias)

        mdrnn = MDRNN(units=3,
                      input_shape=(None, 3),
                      kernel_initializer=kernel_initializer,
                      recurrent_initializer=recurrent_initializer,
                      bias_initializer=bias_initializer,
                      activation=None,
                      return_sequences=True)

        x1 = np.array([1, 2, 4])
        x2 = np.array([9, 8, 6])
        x = np.array([x1, x2])
        a = mdrnn.call(x.reshape((1, 2, -1)))

        expected_result = np.array([[x1 + bias, x1 + x2 + 2 * bias]])
        np.testing.assert_almost_equal(expected_result, a.numpy(), 8)
    def make_rnns(self, return_sequences, return_state):
        seed = 1
        kwargs = dict(units=3, input_shape=(None, 5),
                      kernel_initializer=initializers.glorot_uniform(seed),
                      recurrent_initializer=initializers.he_normal(seed),
                      bias_initializer=initializers.Constant(2),
                      return_sequences=return_sequences,
                      return_state=return_state,
                      activation='relu'
                      )
        rnn = MultiDirectional(MDRNN(**kwargs))

        keras_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.SimpleRNN(**kwargs))
        return rnn, keras_rnn
    def test_feeding_5_dimensional_rnn_returns_sequences_and_last_state(self):
        rnn = MDRNN(units=3,
                    input_shape=(None, None, None, None, None, 1),
                    return_sequences=True,
                    return_state=True)

        shape = (2, 2, 3, 1, 2, 6, 1)
        x = np.arange(2 * 2 * 3 * 1 * 2 * 6).reshape(shape)
        res = rnn(x)
        self.assertEqual(2, len(res))
        sequences, state = res

        expected_sequence_shape = (2, 2, 3, 1, 2, 6, 3)
        expected_state_shape = (2, 3)
        self.assertEqual(expected_sequence_shape, sequences.shape)
        self.assertEqual(expected_state_shape, state.shape)
    def test_fit_multi_directional(self):
        x = np.zeros((10, 2, 3, 6))
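        # a 2-dimensional MultiDirectional RNN concatenates the outputs of all
        # 2**2 = 4 directions, so 10 units per direction yield 40 output features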
        y = np.zeros((10, 40))

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Input(shape=(2, 3, 6)))
        model.add(MultiDirectional(MDRNN(10, input_shape=[2, 3, 6])))

        model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001,
                                                         clipnorm=100),
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        model.summary()

        model.fit(x, y, epochs=1)
Example 19
    def create_model(self, input_dim, units, seq_len, return_sequences,
                     output_units=1, output_activation=None):

        inp = tf.keras.layers.Input(shape=[seq_len, input_dim])

        x = inp

        mdrnn = MDRNN(units=units, input_shape=(None, input_dim),
                      return_sequences=return_sequences)

        densor = tf.keras.layers.Dense(units=output_units, activation=output_activation)

        x = mdrnn(x)

        output = densor(x)

        model = tf.keras.Model(inputs=inp, outputs=output)
        model.compile(optimizer='sgd', loss='mean_squared_error', metrics=[])
        return model
    def make_rnns(self, return_sequences, return_state, go_backwards=False):
        seed = 1
        kwargs = dict(units=3, input_shape=(None, 5),
                      kernel_initializer=initializers.glorot_uniform(seed),
                      recurrent_initializer=initializers.he_normal(seed),
                      bias_initializer=initializers.Constant(2),
                      return_sequences=return_sequences,
                      return_state=return_state,
                      activation='relu'
                      )

        mdrnn_kwargs = dict(kwargs)
        if go_backwards:
            mdrnn_kwargs.update(dict(direction=Direction(-1)))
        rnn = MDRNN(**mdrnn_kwargs)

        keras_rnn_kwargs = dict(kwargs)
        keras_rnn_kwargs.update(dict(go_backwards=go_backwards))
        keras_rnn = tf.keras.layers.SimpleRNN(**keras_rnn_kwargs)
        return rnn, keras_rnn
def fit_mdrnn(target_image_size=10, rnn_units=128, epochs=30, batch_size=32):
    # get MNIST examples
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # down sample images to speed up the training and graph building process for mdrnn
    x_train = down_sample(x_train, target_image_size)
    x_test = down_sample(x_test, target_image_size)

    inp = tf.keras.layers.Input(shape=(target_image_size, target_image_size,
                                       1))

    # create multi-directional MDRNN layer
    rnn = MultiDirectional(
        MDRNN(units=rnn_units,
              input_shape=[target_image_size, target_image_size, 1]))

    dense = tf.keras.layers.Dense(units=10, activation='softmax')

    # build a model
    x = inp
    x = rnn(x)
    outputs = dense(x)
    model = tf.keras.Model(inp, outputs)

    # choose Adam optimizer, set gradient clipping to prevent gradient explosion,
    # set a categorical cross-entropy loss function
    model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001, clipnorm=100),
                  loss='categorical_crossentropy',
                  metrics=['acc'])
    model.summary()

    # fit the model
    model.fit(x_train,
              tf.keras.utils.to_categorical(y_train),
              epochs=epochs,
              validation_data=(x_test, tf.keras.utils.to_categorical(y_test)),
              batch_size=batch_size)
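fit_mdrnn relies on a down_sample helper that is not part of this listing. Below is a minimal sketch of one way to implement it, shrinking the 28x28 MNIST digits with area (average-pooling) interpolation; the exact resizing method and the [0, 1] scaling are assumptions.

import numpy as np
import tensorflow as tf


def down_sample(images, target_size):
    """Resize (N, 28, 28) uint8 images to (N, target_size, target_size, 1) floats in [0, 1]."""
    images = images.astype('float32')[..., np.newaxis] / 255.0
    # 'area' interpolation averages source pixels, acting like average pooling
    resized = tf.image.resize(images, (target_size, target_size), method='area')
    return resized.numpy()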
Example 22
    def create_mdrnn(self, direction):
        input_dimensions = [None] * self.ndims
        shape = tuple(input_dimensions + [self.input_dim])
        return MDRNN(units=self.units, input_shape=shape, direction=direction)
Example 23
    def make_rnn(self, **kwargs):
        return MDRNN(**kwargs)
Example 24
import math

import numpy as np
import torch
import tqdm


def ACO(aco_iterations, n_inputs, n_outputs, n_hiddens, pheromones,
        training_data, test_data, n_gaussians):
    """
    Run the ACO algorithm for aco_iterations iterations.
    """
    # Store data:
    avg_train_losses = []
    avg_test_losses = []
    
    # Hyper parameters
    n_models = 10
    deg_freq = 5
    batch_size = 16
    n_epochs = 10
    lr = 0.01
    
    # Initialize population
    population = []
    cur_best = math.inf
    for iteration in tqdm.tqdm(range(aco_iterations)):
        # Generate paths for the models
        
        train_losses = []
        test_losses = []
        
        for n in range(n_models):
            
            # Define model and paths
            model = MDRNN(n_inputs, n_outputs, n_hiddens, n_gaussians).float()
            paths = Paths(n_inputs, n_hiddens, pheromones)
            
            # Prune the model
            prune_layers(model, paths, n_inputs, n_hiddens)
         
            # Training loop 
            loss = train(model, training_data, n_epochs, batch_size, lr)
            print("\nTrain loss for model {}: {:.4f}".format(n+1, loss.item()))
            
            # Update fitness
            model.fitness = test(model, test_data, batch_size)
            print("Test loss for model {}: {}".format(n+1, model.fitness))
            
            # Save losses:
            train_losses.append(loss.item())
            test_losses.append(model.fitness)
            
            # Add model to population
            population.append((model, paths))
        
        # Update pheromones
        for model, paths in population:
            if model.fitness < cur_best:  # Reward
                torch.save(model.state_dict(), 'model_weights.pth')
                cur_best = model.fitness
                pheromones.update(paths, 0)
            else:  # Punishment 
                pheromones.update(paths, 1)
            if iteration % deg_freq == 0:  # Decay step
                pheromones.update(paths, 2)
        
        avg_train_losses.append(np.average(train_losses))
        avg_test_losses.append(np.average(test_losses))
    
    return avg_train_losses, avg_test_losses
    def create_mdrnn(self, **kwargs):
        return MDRNN(**kwargs)
Example 26
def main(cfg):
    ################
    # Constants
    ################

    epochs = cfg.epochs
    sequence_horizon = cfg.sequence_horizon
    batch_size = cfg.batch_size

    rnn_hidden_size = cfg.rnn_hidden_size
    latent_size = cfg.latent_size
    vae_hidden_size = cfg.vae_hidden_size
    num_gaussians = cfg.num_gaussians

    ##################
    # Loading datasets
    ##################
    # Data Loading
    dataset = DynamicsDataset(cfg.dataset.train_path, horizon=sequence_horizon)
    collate_fn = get_standardized_collate_fn(dataset, keep_dims=True)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
    action_dims = len(dataset.actions[0][0])
    state_dims = len(dataset.states[0][0])

    # load test dataset
    val_dataset = DynamicsDataset(cfg.dataset.test_path, horizon=sequence_horizon)
    collate_fn = get_standardized_collate_fn(val_dataset, keep_dims=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)

    # Only do 1-step predictions
    test_dataset = DynamicsDataset(cfg.dataset.test_path, horizon=1)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    input_size = state_dims
    action_size = action_dims

    # Loading VAE
    vae_file = join(cfg.logdir, 'vae', 'best.tar')
    assert exists(vae_file), "No trained VAE in the logdir..."
    state = torch.load(vae_file)
    print("Loading VAE at epoch {} "
        "with test error {}".format(
            state['epoch'], state['precision']))

    vae = VAE(input_size, latent_size, vae_hidden_size).to(device)
    vae.load_state_dict(state['state_dict'])

    # Loading model
    rnn_dir = join(cfg.logdir, 'mdrnn')
    rnn_file = join(rnn_dir, 'best.tar')

    if not exists(rnn_dir):
        mkdir(rnn_dir)

    mdrnn = MDRNN(latent_size, action_size, rnn_hidden_size, num_gaussians)

    mdrnn.to(device)
    optimizer = torch.optim.Adam(mdrnn.parameters(), lr=1e-3)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)

    if exists(rnn_file) and not cfg.noreload:
        rnn_state = torch.load(rnn_file)
        print("Loading MDRNN at epoch {} "
              "with test error {}".format(rnn_state["epoch"], rnn_state["precision"]))
        mdrnn.load_state_dict(rnn_state["state_dict"])
        optimizer.load_state_dict(rnn_state["optimizer"])
        scheduler.load_state_dict(rnn_state["scheduler"])
        earlystopping.load_state_dict(rnn_state["earlystopping"])

    def save_checkpoint(state, is_best, filename, best_filename):
        """ Save state in filename. Also save in best_filename if is_best. """
        torch.save(state, filename)
        if is_best:
            torch.save(state, best_filename)

    def to_latent(obs, next_obs, batch_size=1, sequence_horizon=1):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)
        :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, ASIZE, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
            - latent_next_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():

            # 1: is to ignore the reconstruction, just get mu and logsigma
            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                vae(x)[1:] for x in (obs, next_obs)]
            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(batch_size,
                    sequence_horizon, latent_size)
                for x_mu, x_logsigma in
                [(obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]]
        return latent_obs, latent_next_obs

    def get_loss(latent_obs, action, reward, terminal,
                latent_next_obs, include_reward: bool):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
            BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
        approximately linearly with LSIZE. All losses are averaged over both the
        batch and the sequence dimensions (the first two dimensions). In this
        variant the BCE term on terminals is disabled.

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        latent_obs, action, reward, terminal, latent_next_obs = [
            arr.transpose(1, 0)
            for arr in [latent_obs, action, reward, terminal, latent_next_obs]]

        mus, sigmas, logpi, pi, rs, ds = mdrnn(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        # bce = f.binary_cross_entropy_with_logits(ds, terminal)
        bce = 0
        if include_reward:
            mse = f.mse_loss(rs, reward)
            scale = latent_size + 2
        else:
            mse = 0
            scale = latent_size + 1
        # loss = (gmm + bce + mse) / scale
        loss = (gmm + mse) / scale
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)


    def data_pass(epoch, train, include_reward):
        """ One pass through the data """
        if train:
            mdrnn.train()
            loader = train_loader
        else:
            mdrnn.eval()
            loader = val_loader

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        for batch_index, (states, actions, next_states, rewards, _, _) in enumerate(loader):
            if batch_index > 1000:
                break
            states = states.to(device)
            next_states = next_states.to(device)
            rewards = rewards.to(device)
            actions = actions.to(device)
            # Terminals (used only by the disabled BCE term) are not stored in
            # the dataset, so set them all to false.
            terminal = torch.zeros(batch_size, sequence_horizon).to(device)

            latent_obs, latent_next_obs = to_latent(states, next_states,
                batch_size=batch_size, sequence_horizon=sequence_horizon)
        
            if train:
                losses = get_loss(latent_obs, actions, rewards,
                                terminal, latent_next_obs, include_reward)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, actions, rewards,
                                    terminal, latent_next_obs,include_reward)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            # cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item() if hasattr(losses['mse'], 'item') else \
                losses['mse']
            data_size = len(loader)

            if batch_index % 100 == 0:
                print("Train" if train else "Test")

                print("loss={loss:10.6f} bce={bce:10.6f} "
                                    "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                        loss=cum_loss / data_size, bce=cum_bce / data_size,
                                        gmm=cum_gmm / latent_size / data_size, mse=cum_mse / data_size))

        return cum_loss * batch_size / len(loader.dataset)

    
    def predict():
        mdrnn.eval()

        preds = []
        gt = []
        n_episodes = test_dataset[-1][-2] + 1
        predictions = [[] for _ in range(n_episodes)]
        with torch.no_grad():
            for batch_index, (states, actions, next_states, rewards, episode, timesteps) in enumerate(test_loader):

                states = states.to(device)
                next_states = next_states.to(device)
                rewards = rewards.to(device)
                actions = actions.to(device)

                latent_obs, _ = to_latent(states, next_states,
                                          batch_size=1, sequence_horizon=1)

                # Check model's next state predictions
                mus, sigmas, logpi, _, _, _ = mdrnn(actions, latent_obs)
                mix = D.Categorical(logits=logpi)  # logpi holds log mixture weights
                comp = D.Independent(D.Normal(mus, sigmas), 1)
                gmm = D.MixtureSameFamily(mix, comp)
                sample = gmm.sample()

                decoded_states = vae.decoder(sample).squeeze(0)
                decoded_states = decoded_states.cpu().detach().numpy()
                preds.append(decoded_states)

                for i in range(len(states)):
                    predictions[episode[i].int()].append(np.expand_dims(decoded_states[i], axis=0))


                gt.append(next_states.cpu().detach().numpy())
            predictions = [np.stack(p) for p in predictions]
            preds = np.asarray(preds)
            gt = np.asarray(gt).squeeze(1)
            error = (preds - gt)**2

        path = cfg.logdir + '/' + cfg.resname + '.pkl'
        pickle.dump(predictions, open(path, 'wb'))

        print("Mean Error: {}".format(error.mean(0)[0]))
        print("Min  Error: {}".format(error.min(0)[0]))
        print("Max  Error: {}".format(error.max(0)[0]))

    train = partial(data_pass, train=True, include_reward=cfg.include_reward)
    test = partial(data_pass, train=False, include_reward=cfg.include_reward)

    cur_best = None
    for e in range(epochs):
        train(e)
        test_loss = test(e)
        predict()

        scheduler.step(test_loss)
        #earlystopping.step(test_loss)

        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        save_checkpoint({
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "earlystopping": earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e}, is_best, checkpoint_fname, rnn_file)
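The get_loss helper inside main above delegates its density term to gmm_loss, which is not shown in this listing. The following is a sketch of the standard negative log-likelihood of a diagonal-covariance Gaussian mixture, consistent with the (SEQ_LEN, BSIZE, ...) tensors produced by the transpose in get_loss; treat it as one common formulation rather than the exact function imported there.

import torch
from torch.distributions.normal import Normal


def gmm_loss(batch, mus, sigmas, logpi, reduce=True):
    """NLL of batch under a GMM; mus, sigmas, and logpi carry an extra mixture dim."""
    batch = batch.unsqueeze(-2)  # broadcast targets across the mixture dimension
    g_log_probs = Normal(mus, sigmas).log_prob(batch)
    g_log_probs = logpi + torch.sum(g_log_probs, dim=-1)  # joint log-prob per component
    log_prob = torch.logsumexp(g_log_probs, dim=-1)  # stable sum over components
    return -torch.mean(log_prob) if reduce else -log_prob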
    def create_default_mdrnn(self, **kwargs):
        return MDRNN(units=3, input_shape=(None, 3), **kwargs)
    def test_feeding_layer_created_with_default_initializer(self):
        mdrnn = MDRNN(units=2, input_shape=(None, 1))
        x = np.ones((1, 1, 1)) * 0.5
        mdrnn.call(x)