Code example #1
# Imports inferred from usage; normalize_frame, converter and viewer are
# project-local helpers defined elsewhere in this repository.
import numpy as np
import chainer
import chainer.functions as F
from chainer.dataset import concat_examples, to_device


def animate_ff(model, frames, actions, mean_image, args):
    """Roll out the feed-forward model over a frame sequence, collecting
    predicted frames and ground-truth frames as HWC images."""
    initial_frames = np.concatenate([
        normalize_frame(frame) - mean_image
        for frame in frames[0:args.skip_frames]
    ])
    initial_action = actions[args.skip_frames - 1].astype(np.float32)

    input_frames, input_action = concat_examples(batch=[(initial_frames,
                                                         initial_action)],
                                                 device=args.gpu)
    print('input_frames shape: {}'.format(input_frames.shape))

    predicted_frames = []
    ground_truths = []
    with chainer.no_backprop_mode():
        for next_index in range(args.skip_frames, len(frames)):
            print('next_index: {}'.format(next_index))
            predicted_frame = model((input_frames, input_action))
            next_frames = F.concat((input_frames, predicted_frame),
                                   axis=1)[:, 3:, :, :]
            next_action = actions[next_index].astype(np.float32)

            input_frames = next_frames
            input_action = to_device(device=args.gpu,
                                     x=next_action.reshape((-1, 3, 1)))
            # print('next_frames shape: {}'.format(input_frames.shape))
            # print('next_action shape: {}'.format(input_action.shape))
            # Keep predicted image
            predicted_frame.to_cpu()
            predicted_frames.append(
                converter.chw2hwc(predicted_frame.data[0] + mean_image))
            ground_truth = frames[next_index]
            ground_truths.append(converter.chw2hwc(ground_truth))

    return ground_truths, predicted_frames
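normalize_frame is used here and in animate_lstm below but is not shown. A minimal sketch of what it presumably does, assuming raw uint8 screen frames scaled to float32 in [0, 1] (both the dtype and the scaling are assumptions, not confirmed by the source):

def normalize_frame(frame):
    # Assumed behavior: map raw uint8 pixels (0..255) to float32 [0, 1]
    # so that subtracting mean_image yields zero-centered inputs.
    return frame.astype(np.float32) / 255.0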
Code example #2
def animate_lstm(model, frames, actions, mean_image, args):
    """Warm up the recurrent model on the first init_steps - 1 frames,
    then roll it out on its own predictions."""
    with chainer.no_backprop_mode():
        for step in range(args.init_steps - 1):
            print('step: {}'.format(step))
            frame = frames[step]
            normalized_frame = normalize_frame(frame)
            input_frame = normalized_frame - mean_image
            input_frame = to_device(device=args.gpu,
                                    x=input_frame.reshape((-1, ) +
                                                          input_frame.shape))
            print('input_frame shape: {}'.format(input_frame.shape))
            input_action = actions[step + 1].astype(np.float32)
            input_action = to_device(device=args.gpu,
                                     x=input_action.reshape((-1, 3, 1)))

            model((input_frame, input_action))

    predicted_frames = []
    ground_truths = []

    frame = frames[args.init_steps - 1]
    normalized_frame = normalize_frame(frame)
    next_frame = normalized_frame - mean_image
    next_frame = to_device(device=args.gpu,
                           x=next_frame.reshape((-1, ) + next_frame.shape))

    with chainer.no_backprop_mode():
        for next_index in range(args.init_steps - 1, len(frames) - 1):
            print('next_index: {}'.format(next_index))
            input_frame = next_frame
            input_action = actions[next_index].astype(np.float32)
            input_action = to_device(device=args.gpu,
                                     x=input_action.reshape((-1, 3, 1)))

            predicted_frame = model((input_frame, input_action))

            next_frame = chainer.Variable(predicted_frame.array)
            # Keep predicted image
            predicted_frame.to_cpu()
            predicted_frames.append(
                converter.chw2hwc(predicted_frame.data[0] + mean_image))
            # The model predicts the frame after next_index, so compare
            # against frames[next_index + 1] (the loop bound leaves room).
            ground_truth = frames[next_index + 1]
            ground_truths.append(converter.chw2hwc(ground_truth))

    return ground_truths, predicted_frames
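Note that the recurrent model carries hidden state across calls, which is exactly what the warm-up loop above exploits; calling animate_lstm on a second episode without clearing that state would leak context between episodes. A hedged reset helper, assuming the network is a chainer.Chain whose recurrent layers are standard chainer.links.LSTM links (the model's internals are not shown in the source):

import chainer.links as L

def reset_recurrent_state(model):
    # Walk every link in the chain and clear the hidden/cell state of
    # each LSTM before starting a new rollout.
    for link in model.links():
        if isinstance(link, L.LSTM):
            link.reset_state()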
Code example #3
import argparse  # required by the parser below


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-file', type=str, default='')
    parser.add_argument('--mean-image-file',
                        type=str,
                        default='mean_image.pickle')
    parser.add_argument('--dataset-file', type=str, default='100000.pickle')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--skip-frames', type=int, default=4)
    parser.add_argument('--initial-frame', type=int, default=0)
    parser.add_argument('--last-frame', type=int, default=10000)
    parser.add_argument('--init-steps', type=int, default=11)
    parser.add_argument('--show-dataset', action='store_true')
    parser.add_argument('--show-prediction', action='store_true')
    parser.add_argument('--show-mean-image', action='store_true')
    parser.add_argument('--show-sample-frame', action='store_true')
    parser.add_argument('--lstm', action='store_true')

    args = parser.parse_args()

    if args.show_dataset:
        dataset = load_dataset([args.dataset_file])
        animate_dataset(dataset, args)

    if args.show_prediction:
        dataset = load_dataset([args.dataset_file])
        if args.lstm:
            model = LstmPredictionNetwork()
        else:
            model = FeedForwardPredictionNetwork()
        serializers.load_model(args.model_file, model)
        if args.gpu >= 0:
            model.to_gpu()

        mean_image = ds.load_mean_image(args.mean_image_file)
        animate_predictions(model, dataset, mean_image, args)

    if args.show_mean_image:
        mean_image = ds.load_mean_image(args.mean_image_file)
        viewer.show_image(converter.chw2hwc(mean_image), title='mean image')

    if args.show_sample_frame:
        dataset = load_dataset([args.dataset_file])
        frame = dataset['frames'][args.initial_frame]
        viewer.show_image(converter.chw2hwc(frame), title='frame')
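A script like this is presumably executed directly behind the usual guard; the file name and flag values below are illustrative only:

if __name__ == '__main__':
    # Example invocation (hypothetical script name and arguments):
    #   python show_prediction.py --show-prediction --lstm \
    #       --model-file model.npz --gpu -1
    main()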
Code example #4
def verify_batch(batch, mean_image, args):
    examples = ds.concat_examples_batch(batch)
    (train_frame, train_actions) = examples['input']
    target_frames = examples['target']
    print_shape(train_frame, 'train frame shape: ')
    if args.lstm:
        assert train_frame[0][0].shape == (args.color_channels, IMAGE_HEIGHT,
                                           IMAGE_WIDTH)
        assert train_actions.shape == (args.batch_size, args.unroll_steps,
                                       args.num_actions, 1)
        assert target_frames[0][0].shape == (args.color_channels, IMAGE_HEIGHT,
                                             IMAGE_WIDTH)
    else:
        assert train_frame.shape == (args.batch_size,
                                     args.color_channels * args.skip_frames,
                                     IMAGE_HEIGHT, IMAGE_WIDTH)
        assert train_actions.shape == (args.batch_size, args.k_step,
                                       args.num_actions, 1)
        assert target_frames[0][0].shape == (args.color_channels, IMAGE_HEIGHT,
                                             IMAGE_WIDTH)

    images = []
    titles = []
    if args.lstm:
        for i in range(args.init_steps):
            images.append(converter.chw2hwc(train_frame[0][i] + mean_image))
            titles.append('step={}, actions={}, action={}'.format(
                i, train_actions[0][i], np.argmax(train_actions[0][i])))
        for i in range(args.k_step):
            images.append(converter.chw2hwc(target_frames[0][i] + mean_image))
            titles.append('target frame kstep={}'.format(i + 1))
    else:
        for i in range(args.skip_frames):
            images.append(train_frame[0][i * args.color_channels] +
                          mean_image[0])
            titles.append('channel={}, actions={}, action={}'.format(
                i, train_actions[0][0], np.argmax(train_actions[0][0])))
        for i in range(args.k_step):
            images.append(converter.chw2hwc(target_frames[0][i] + mean_image))
            titles.append('target frame kstep={}'.format(i + 1))

    viewer.show_images(images=images, titles=titles)
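verify_batch calls a print_shape helper that is not shown; the call only implies a prefix string followed by the shape of the data. A minimal sketch consistent with that usage, where the fallback for sequence inputs (the LSTM batch layout) is an assumption:

def print_shape(data, prefix=''):
    # Arrays report their shape directly; sequences (as in the LSTM
    # case) report the shape of each element instead.
    if hasattr(data, 'shape'):
        print('{}{}'.format(prefix, data.shape))
    else:
        print('{}{}'.format(prefix, [x.shape for x in data]))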
Code example #5
def animate_dataset(dataset, args):
    initial_frame = args.initial_frame
    last_frame = args.last_frame
    frames = [converter.chw2hwc(frame) for frame in dataset['frames']]
    viewer.animate(frames[initial_frame:last_frame], titles=['dataset'])
Code example #6
    def test_chw2hwc(self):
        image = np.zeros(shape=(3, 28, 28), dtype=np.float32)
        converted = converter.chw2hwc(image)

        assert converted.shape == (28, 28, 3)
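The test pins down the contract of converter.chw2hwc: a (channels, height, width) array comes back as (height, width, channels). A single transpose satisfies it (a sketch; the repository's converter module may differ):

import numpy as np

def chw2hwc(image):
    # Move the channel axis from position 0 to position 2.
    return np.transpose(image, (1, 2, 0))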