Example #1
 def test_rollout_data(self):
     """ Test rollout sequence dataset """
     transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)))
     dataset = RolloutSequenceDataset('datasets/carracing', 32, transform)
     loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
                                          num_workers=8)
     dataset.load_next_buffer()
     init_time = time.time()
     for i, data in enumerate(loader):
         if i == 150:
             self.assertEqual(data[0].size(), torch.Size([8, 32, 3, 96, 96]))
             break
     end_time = time.time()
     print("WallTime: {}s".format(end_time - init_time))
Example #2
def plot_rollout():
    """ Plot a rollout """
    import matplotlib.pyplot as plt
    import numpy as np
    from torch.utils.data import DataLoader
    from data.loaders import RolloutSequenceDataset
    dataloader = DataLoader(RolloutSequenceDataset(root='datasets/carracing',
                                                   seq_len=900,
                                                   transform=lambda x: x,
                                                   buffer_size=10,
                                                   train=False),
                            batch_size=1,
                            shuffle=True)

    dataloader.dataset.load_next_buffer()

    # setting up subplots
    plt.subplot(2, 2, 1)
    monitor_obs = plt.imshow(np.zeros((64, 64, 3)))
    plt.subplot(2, 2, 2)
    monitor_next_obs = plt.imshow(np.zeros((64, 64, 3)))
    plt.subplot(2, 2, 3)
    monitor_diff = plt.imshow(np.zeros((64, 64, 3)))

    for data in dataloader:
        obs_seq = data[0].numpy().squeeze()
        action_seq = data[1].numpy().squeeze()
        next_obs_seq = data[-1].numpy().squeeze()
        for obs, action, next_obs in zip(obs_seq, action_seq, next_obs_seq):
            monitor_obs.set_data(obs)
            monitor_next_obs.set_data(next_obs)
            monitor_diff.set_data(next_obs - obs)
            print(action)
            plt.pause(.01)
        break
Example #3
 def test_rollout_data(self):
     """ Test rollout sequence dataset """
     transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)))
     dataset = RolloutSequenceDataset('datasets/carracing', 32, transform)
     loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=8,
                                          shuffle=True,
                                          num_workers=8)
     dataset.load_next_buffer()
     init_time = time.time()
     for i, data in enumerate(loader):
         if i == 150:
             self.assertEqual(data[0].size(), torch.Size([8, 32, 3, 96,
                                                          96]))
             break
     end_time = time.time()
     print("WallTime: {}s".format(end_time - init_time))
Example #4
def make_mdrnn_dataset(rollout_dir):
    step_log('3-1. make_mdrnn_dataset START!!')
    transform = transforms.Lambda(lambda x: np.transpose(x,
                                                         (0, 3, 1, 2)) / 255)

    dataset_train = RolloutSequenceDataset(rollout_dir,
                                           M_SEQ_LEN,
                                           transform,
                                           train=True,
                                           buffer_size=30)
    dataset_test = RolloutSequenceDataset(rollout_dir,
                                          M_SEQ_LEN,
                                          transform,
                                          train=False,
                                          buffer_size=10)

    return dataset_train, dataset_test
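
make_mdrnn_dataset returns bare datasets; a minimal usage sketch, wrapping them in DataLoaders the way the other examples on this page do (the rollout directory, batch size and worker count below are placeholders):

from torch.utils.data import DataLoader

dataset_train, dataset_test = make_mdrnn_dataset('datasets/carracing')
train_loader = DataLoader(dataset_train, batch_size=16, num_workers=8,
                          shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=16, num_workers=8)
train_loader.dataset.load_next_buffer()  # fill the first rollout buffer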
Example #5
def plot_rollout():
    """ Plot a rollout """
    import os
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from data.loaders import RolloutSequenceDataset
    dataloader = DataLoader(RolloutSequenceDataset(root=os.path.join(
        args.datasets, 'carracing'),
                                                   seq_len=900,
                                                   transform=lambda x: x,
                                                   buffer_size=10,
                                                   train=False),
                            batch_size=1,
                            shuffle=True)

    dataloader.dataset.load_next_buffer()

    # setting up subplots
    plt.subplot(1, 3, 1)
    monitor_obs = plt.imshow(np.zeros((64, 64, 3)))
    plt.subplot(1, 3, 2)
    monitor_rec_obs = plt.imshow(np.zeros((64, 64, 3)))
    plt.subplot(1, 3, 3)
    monitor_rec_obs_two = plt.imshow(np.zeros((64, 64, 3)))

    for i, data in enumerate(dataloader):
        if i != args.example_num:
            continue
        obs_seq = data[0].numpy().squeeze()
        action_seq = data[1].numpy().squeeze()
        for obs, action in zip(obs_seq, action_seq):
            monitor_obs.set_data(obs.astype(np.uint8))
            with torch.no_grad():
                monitor_rec_obs.set_data(
                    np.transpose(vae(transform(obs))[0].squeeze(), (1, 2, 0)))
                monitor_rec_obs_two.set_data(
                    np.transpose(
                        vae_two(transform(obs))[0].squeeze(), (1, 2, 0)))
            print(action)
            plt.pause(.01)
        break
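
This variant of plot_rollout calls a transform that is defined elsewhere in its source file. A plausible stand-in (an assumption, not the original code) that turns one HxWx3 uint8 frame into the 1x3xHxW float tensor the VAE expects:

import numpy as np
import torch

def transform(frame):
    # hypothetical helper: HWC uint8 frame -> 1x3xHxW float tensor in [0, 1]
    chw = np.transpose(frame, (2, 0, 1)) / 255.0
    return torch.as_tensor(chw, dtype=torch.float32).unsqueeze(0)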
Example #6
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])


# Data Loading
transform = transforms.Lambda(
    lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(
    RolloutSequenceDataset('datasets/carracing', SEQ_LEN, transform, buffer_size=30),
    batch_size=BSIZE, num_workers=8, shuffle=True)
test_loader = DataLoader(
    RolloutSequenceDataset('datasets/carracing', SEQ_LEN, transform, train=False, buffer_size=10),
    batch_size=BSIZE, num_workers=8)

def to_latent(obs, next_obs):
    """ Transform observations to latent space.

    :args obs: 5D torch tensor (BSIZE, SEQ_LEN, 3, SIZE, SIZE)
    :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, 3, SIZE, SIZE)

    :returns: (latent_obs, latent_next_obs)
        - latent_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        - latent_next_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
    """
Example #7
earlystopping = EarlyStopping('min', patience=30)

if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])

# Data Loading
transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(RolloutSequenceDataset('datasets/pacman',
                                                 SEQ_LEN,
                                                 transform,
                                                 buffer_size=30),
                          batch_size=BSIZE,
                          num_workers=8,
                          shuffle=True)
test_loader = DataLoader(RolloutSequenceDataset('datasets/pacman',
                                                SEQ_LEN,
                                                transform,
                                                train=False,
                                                buffer_size=10),
                         batch_size=BSIZE,
                         num_workers=8)


def to_latent(obs, next_obs):
    """ Transform observations to latent space.
Example #8
print('CUDA: {}'.format(cuda))
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")

# constants
BSIZE = 16
SEQ_LEN = args.sequence_length
epochs = 200

# Data Loading
transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(RolloutSequenceDataset(join(args.datasets,
                                                      'carracing'),
                                                 SEQ_LEN,
                                                 transform,
                                                 buffer_size=30),
                          batch_size=BSIZE,
                          num_workers=8,
                          shuffle=True)
test_loader = DataLoader(RolloutSequenceDataset(join(args.datasets,
                                                     'carracing'),
                                                SEQ_LEN,
                                                transform,
                                                train=False,
                                                buffer_size=10),
                         batch_size=BSIZE,
                         num_workers=8)

# Load VAE
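
The snippet above breaks off at the VAE-loading step; Example #12 below shows that step in full. In outline (a sketch following Example #12, with args.logdir assumed):

vae_file = join(args.logdir, 'vae', 'best.tar')
assert exists(vae_file), "No trained VAE in the logdir..."
state = torch.load(vae_file)
vae = VAE(3, LSIZE).to(device)
vae.load_state_dict(state['state_dict'])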
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    print("MDRNN loaded from {}".format(prev_rnn_file))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])
    earlystopping.load_state_dict(state["earlystopping"])

# Data Loading
dataset_dir = args.dataset_dir
if args.iteration_num is not None:
    dataset_dir = join(dataset_dir, "iter_{}".format(args.iteration_num))

transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(
    RolloutSequenceDataset(dataset_dir, SEQ_LEN, transform, buffer_size=30),
    batch_size=BSIZE,
    num_workers=8,
    shuffle=True,
)
test_loader = DataLoader(
    RolloutSequenceDataset(dataset_dir,
                           SEQ_LEN,
                           transform,
                           train=False,
                           buffer_size=10),
    batch_size=BSIZE,
    num_workers=8,
)

Example #10
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with val error {}".format(rnn_state["epoch"],
                                     rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])

# Data Loading
transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(RolloutSequenceDataset(
    'datasets/carnav',
    args.seq_len,
    transform,
    num_val_rollouts=args.num_val_rollouts),
                          batch_size=args.batch_size,
                          num_workers=8,
                          shuffle=True,
                          drop_last=True)
val_loader = DataLoader(RolloutSequenceDataset(
    'datasets/carnav',
    args.seq_len,
    transform,
    train=False,
    num_val_rollouts=args.num_val_rollouts),
                        batch_size=args.batch_size,
                        num_workers=8,
                        drop_last=True)
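
This is the only variant that passes drop_last=True. That guards the later reshape of the latents with a hard-coded batch size (see .view(BSIZE, SEQ_LEN, LSIZE) in Example #12): a shorter final batch would make the view fail. A toy illustration:

import torch

BSIZE, SEQ_LEN, LSIZE = 8, 32, 5           # toy sizes
full = torch.zeros(BSIZE * SEQ_LEN, LSIZE)
full.view(BSIZE, SEQ_LEN, LSIZE)           # fine
short = torch.zeros(5 * SEQ_LEN, LSIZE)    # a final batch of only 5 rollouts
# short.view(BSIZE, SEQ_LEN, LSIZE)        # RuntimeError: shape is invalid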
Example #11
earlystopping = EarlyStopping('min', patience=30)

if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])

# Data Loading
transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(RolloutSequenceDataset(r'D:\steps1000\datasets\mgw',
                                                 SEQ_LEN,
                                                 transform,
                                                 buffer_size=30),
                          batch_size=BSIZE,
                          num_workers=0,
                          shuffle=True)
test_loader = DataLoader(RolloutSequenceDataset(r'D:\steps1000\datasets\mgw',
                                                SEQ_LEN,
                                                transform,
                                                train=False,
                                                buffer_size=10),
                         batch_size=BSIZE,
                         num_workers=0)


def to_latent(obs, next_obs):
    """ Transform observations to latent space.
Example #12
def train_mdrnn(logdir, traindir, epochs=10, testdir=None):
    BSIZE = 80  # the original implementation used a batch size of 16
    noreload = False  # if True, do not reload a previously saved best model
    SEQ_LEN = 32
    epochs = int(epochs)

    testdir = testdir if testdir else traindir
    cuda = torch.cuda.is_available()

    torch.manual_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Loading VAE
    vae_file = join(logdir, 'vae', 'best.tar')
    assert exists(vae_file), "No trained VAE in the logdir..."
    state = torch.load(vae_file)
    print("Loading VAE at epoch {} "
          "with test error {}".format(
              state['epoch'], state['precision']))

    vae = VAE(3, LSIZE).to(device)
    vae.load_state_dict(state['state_dict'])

    # Loading model
    rnn_dir = join(logdir, 'mdrnn')
    rnn_file = join(rnn_dir, 'best.tar')

    if not exists(rnn_dir):
        mkdir(rnn_dir)

    mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
    mdrnn.to(device)
    optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)


    if exists(rnn_file) and not noreload:
        rnn_state = torch.load(rnn_file)
        print("Loading MDRNN at epoch {} "
              "with test error {}".format(
                  rnn_state["epoch"], rnn_state["precision"]))
        mdrnn.load_state_dict(rnn_state["state_dict"])
        optimizer.load_state_dict(rnn_state["optimizer"])
        scheduler.load_state_dict(rnn_state['scheduler'])
        earlystopping.load_state_dict(rnn_state['earlystopping'])


    # Data Loading
    transform = transforms.Lambda(
        lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
    train_loader = DataLoader(
        RolloutSequenceDataset(traindir, SEQ_LEN, transform, buffer_size=30),
        batch_size=BSIZE, num_workers=8, shuffle=True)
    test_loader = DataLoader(
        RolloutSequenceDataset(testdir, SEQ_LEN, transform, train=False, buffer_size=10),
        batch_size=BSIZE, num_workers=8)

    def to_latent(obs, next_obs):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (BSIZE, SEQ_LEN, 3, SIZE, SIZE)
        :args next_obs: 5D torch tensor (BSIZE, SEQ_LEN, 3, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
            - latent_next_obs: 3D torch tensor (BSIZE, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():
            obs, next_obs = [
                f.interpolate(x.view(-1, 3, SIZE, SIZE), size=RED_SIZE,
                              mode='bilinear', align_corners=True)
                for x in (obs, next_obs)]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                vae(x)[1:] for x in (obs, next_obs)]

            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(BSIZE, SEQ_LEN, LSIZE)
                for x_mu, x_logsigma in
                [(obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]]
        return latent_obs, latent_next_obs
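
    # Note: to_latent draws the latents with the reparameterisation trick
    # (z = mu + exp(logsigma) * eps, eps ~ N(0, I)) rather than using the
    # posterior mean, so every data pass sees a fresh sample from the encoder.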

    def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is there to counteract the fact that the GMMLoss
        scales approximately linearly with LSIZE. All losses are averaged over
        both the batch and the sequence dimensions (the first two dimensions).

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args terminal: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        latent_obs, action,\
            reward, terminal,\
            latent_next_obs = [arr.transpose(1, 0)
                               for arr in [latent_obs, action,
                                           reward, terminal,
                                           latent_next_obs]]
        mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce = f.binary_cross_entropy_with_logits(ds, terminal)
        mse = f.mse_loss(rs, reward)
        loss = (gmm + bce + mse) / (LSIZE + 2)
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)


    def data_pass(epoch, train): # pylint: disable=too-many-locals
        """ One pass through the data """
        if train:
            mdrnn.train()
            loader = train_loader
        else:
            mdrnn.eval()
            loader = test_loader

        loader.dataset.load_next_buffer()

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward,
                                  terminal, latent_next_obs)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward,
                                      terminal, latent_next_obs)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item()

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1), bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1), mse=cum_mse / (i + 1)))
            pbar.update(BSIZE)
        pbar.close()
        return cum_loss * BSIZE / len(loader.dataset)

    train = partial(data_pass, train=True)
    test = partial(data_pass, train=False)

    cur_best = None
    for e in range(epochs):
        train(e)
        test_loss = test(e)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        is_best = not cur_best or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        save_checkpoint({
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e}, is_best, checkpoint_fname,
                        rnn_file)

        if earlystopping.stop:
            print("End of Training because of early stopping at epoch {}".format(e))
            break
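
train_mdrnn relies on a save_checkpoint helper that is not shown on this page. A minimal sketch of the usual pattern (an assumption; the real helper lives elsewhere in the project):

import shutil
import torch

def save_checkpoint(state, is_best, filename, best_filename):
    # hypothetical helper: always save the latest state, and copy it over
    # the best checkpoint when the current test loss is the best so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)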
Example #13
earlystopping = EarlyStopping('min', patience=30)

if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])

# Data Loading
transform = transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
train_loader = DataLoader(RolloutSequenceDataset(args.traindir,
                                                 SEQ_LEN,
                                                 transform,
                                                 buffer_size=30),
                          batch_size=BSIZE,
                          num_workers=8,
                          shuffle=True)
test_loader = DataLoader(RolloutSequenceDataset(args.testdir,
                                                SEQ_LEN,
                                                transform,
                                                train=False,
                                                buffer_size=10),
                         batch_size=BSIZE,
                         num_workers=8)


def to_latent(obs, next_obs):
    """ Transform observations to latent space.