Beispiel #1
0
def main(env_str=None, visualize=True, f=None):
    """
    Run a short random-action roll-out in one of the ivy_gym environments.

    :param env_str: Name of the environment class to look up on ivy_gym. Must be provided; the None
                    default is rejected before any framework is set.
    :param visualize: Whether to render the environment after every step.
    :param f: Framework to set for the run. A random framework is chosen if None.
    :raises TypeError: If env_str is None.
    """

    # getattr() would raise TypeError for a None name anyway, but only after the framework
    # had already been pushed; fail early with a clearer message instead
    if env_str is None:
        raise TypeError('env_str must be the name of an ivy_gym environment, but found None')

    # Framework Setup #
    # ----------------#

    # choose random framework
    f = choose_random_framework() if f is None else f
    ivy.set_framework(f)
    try:

        # get environment
        env = getattr(ivy_gym, env_str)()
        try:

            # run environment steps
            env.reset()
            ac_dim = env.action_space.shape[0]
            for _ in range(250):
                ac = ivy.random_uniform(-1, 1, (ac_dim, ))
                env.step(ac)
                if visualize:
                    env.render()
        finally:
            # release the environment even if a step or render call fails
            env.close()
    finally:
        # always pop the framework pushed above
        ivy.unset_framework()

    # message
    print('End of Run Through Demo!')
Beispiel #2
0
 def _worker_fn(index_queue, output_queue, dataset, numpy_loading):
     """
     Worker loop: serve slice requests from index_queue until a None sentinel arrives.

     :param index_queue: Queue from which slice requests are read; None asks the worker to stop.
     :param output_queue: Queue onto which each loaded item is put, as a dict.
     :param dataset: The dataset to slice; it is closed when the sentinel is received.
     :param numpy_loading: Whether to set the numpy framework while the slice is loaded.
     """
     while True:
         try:
             request = index_queue.get(timeout=1.0)
         except queue.Empty:
             # nothing was requested within the timeout, poll again
             continue

         # a None request is the shutdown sentinel
         if request is None:
             dataset.close()
             return

         # load the requested slice, under the numpy framework if requested
         if numpy_loading:
             ivy.set_framework('numpy')
             loaded = Dataset._slice_dataset_with_error_checks(dataset, request)
             ivy.unset_framework()
         else:
             loaded = Dataset._slice_dataset_with_error_checks(dataset, request)

         if ivy.wrapped_mode():
             loaded = loaded.to_native(nested=True)
         output_queue.put(loaded.to_dict())
Beispiel #3
0
        dataset_dirs_class=ExampleDatasetDirs,
        dataset_spec_args=dataset_spec_args,
        dataset_spec_class=ExampleDatasetSpec,
        data_loader_spec_args=data_loader_spec_args,
        data_loader_spec_class=ExampleDataLoaderSpec,
        network_spec_args=network_spec_args,
        network_spec_class=ExampleNetworkSpec,
        trainer_spec_args=trainer_spec_args,
        spec_cont=ivy.Container({'trainer': {
            'compile_mode': compile_mode
        }}))
    trainer.setup()
    print("Finished complete example!")
    trainer.train()
    trainer.close()


if __name__ == '__main__':
    # parse the optional framework choice from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--framework',
        type=str,
        default=None,
        help=
        'which framework to use. Chooses a random framework if unspecified.')
    parsed_args = parser.parse_args()
    f = ivy.default(parsed_args.framework, ivy.choose_random_framework())
    ivy.set_framework(f)
    try:
        main()
    finally:
        # always pop the framework pushed above, even if main() raises
        ivy.unset_framework()
Beispiel #4
0
 def __getitem__(self, slice_obj):
     """
     Return the dataset item(s) selected by slice_obj.

     With fewer than two processes, or for a numeric index, the item is fetched directly via
     self._get_item. Otherwise the slice is split into at most self._num_processes sub-slices which
     are put onto the worker slice queues; the result is then either a container wrapping the output
     queues (prefetching) or the joined items read back from those queues.

     When numpy loading is enabled, the numpy framework is set for the duration of the loading work
     and is always unset again, including when loading raises (e.g. a queue timeout).

     :param slice_obj: A number, or a slice object with start and stop set.
     :return: The item returned by self._get_item, or an ivy.Container as described above.
     """
     if not self._workers_initialized:
         self._initialize_all_workers()

     # read once so that the set/unset pair below is guaranteed to be balanced
     numpy_loading = self._numpy_loading
     if numpy_loading:
         ivy.set_framework('numpy')

     direct_item = None
     prefetch_kwargs = None
     items_as_lists = None
     try:
         if self._num_processes < 2 or isinstance(slice_obj, numbers.Number):
             direct_item = self._get_item(slice_obj)
         else:
             # split the requested range into one sub-slice per worker used
             slice_size = int(round(slice_obj.stop - slice_obj.start))
             num_sub_slices = min(slice_size, self._num_processes)
             slice_points = np.linspace(slice_obj.start, slice_obj.stop,
                                        num_sub_slices + 1)
             slice_sizes = np.round(slice_points[1:] - slice_points[:-1]).astype(
                 np.int32)
             if Dataset._is_int(slice_obj.start) and Dataset._is_int(
                     slice_obj.stop):
                 slice_points = np.round(slice_points)
             sub_slices = [
                 slice(slice_points[i], slice_points[i + 1], 1.)
                 for i in range(num_sub_slices)
             ]

             # choose which worker queue the first sub-slice goes to
             if self._prefetching:
                 self._queue_offset = int(not self._queue_offset)
             else:
                 self._queue_offset = np.random.randint(0, self._num_processes)
             q_idxs = [
                 int((i + self._queue_offset) % self._num_processes)
                 for i in range(len(sub_slices))
             ]
             slice_queues = [self._slice_queues[qi] for qi in q_idxs]
             output_queues = [self._output_queues[qi] for qi in q_idxs]

             if self._prefetching:
                 if self._first_pass:
                     for slice_queue, sub_slice in zip(slice_queues, sub_slices):
                         slice_queue.put(sub_slice)
                 else:
                     # after the first pass only the last sub-slice is queued
                     slice_queues[-1].put(sub_slices[-1])
                 prefetch_kwargs = dict(queues=output_queues,
                                        queue_load_sizes=slice_sizes,
                                        queue_timeout=self._queue_timeout)
             else:
                 for slice_queue, sub_slice in zip(slice_queues, sub_slices):
                     slice_queue.put(sub_slice)
                 if ivy.wrapped_mode():
                     items_as_lists = [
                         ivy.Container(output_queue.get(
                             timeout=self._queue_timeout)).to_ivy()
                         for output_queue in output_queues
                     ]
                 else:
                     items_as_lists = [
                         ivy.Container(
                             output_queue.get(timeout=self._queue_timeout))
                         for output_queue in output_queues
                     ]
     finally:
         # pop the numpy framework whether or not loading succeeded
         if numpy_loading:
             ivy.unset_framework()

     self._first_pass = False

     # the returned containers are deliberately built after the numpy framework has been unset
     if prefetch_kwargs is not None:
         return ivy.Container(**prefetch_kwargs)
     if items_as_lists is not None:
         return ivy.Container.list_join(items_as_lists)
     return direct_item
Beispiel #5
0
def print_json_args(base_dir=None, keys_to_ignore=None, keychains_to_ignore=None):
    """
    Print the json specification arguments found for a directory, optionally diffed against a second directory.

    Options are read from the command line via argparse. The numpy framework is set while the function runs and
    is always unset again, including when argument validation raises.

    :param base_dir: Directory which sub_directory and diff_directory are resolved against.
                     Defaults to the current working directory.
    :param keys_to_ignore: Default for the --keys_to_ignore command line option.
    :param keychains_to_ignore: Default for the --keychains_to_ignore command line option.
    :raises Exception: If show_diff_only or show_same_only is given without diff_directory, if both are given
                       together, or if diff_directory resolves to the same path as the sub-directory.
    """
    if not ivy.exists(base_dir):
        base_dir = os.getcwd()

    def _parse_list_arg(raw, sep):
        # drop the first and last character of the whole argument (presumably the enclosing brackets),
        # split on sep, then drop the first and last character of each entry (presumably quotes)
        return [kc[1:-1] for kc in ''.join(raw[1:-1]).split(sep)]

    ivy.set_framework('numpy')
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-sd', '--sub_directory', type=str,
                            help='A sub-directory to print the json args for, default is base_dir passed in.')
        parser.add_argument('-dd', '--diff_directory', type=str,
                            help='The directory from which to compare the difference in specifications.')
        parser.add_argument('-kti', '--keys_to_ignore', type=str, default=keys_to_ignore,
                            help='Keys to ignore when printing the specification.')
        parser.add_argument('-kcti', '--keychains_to_ignore', type=str, default=keychains_to_ignore,
                            help='Key-chains to ignore when printing the specification.')
        parser.add_argument('-kcts', '--keychain_to_show', type=str,
                            help='The key-chain to show. Default is None, in which case all key-chains are shown.')
        parser.add_argument('-sn', '--spec_names', type=str,
                            help='The specification names for the json files. Default is ivy_builder defaults of '
                                 '[ dataset_dirs | dataset | data_loader| network | trainer |]')
        parser.add_argument('-d', '--show_defaults', action='store_true',
                            help='Whether to show the default json arguments. '
                                 'If unset then the current arguments are shown, not the default values.')
        parser.add_argument('-c', '--current_dir_only', action='store_true',
                            help='Whether to only show the json arguments for the current directory, '
                                 'without searching through parent directories also.')
        parser.add_argument('-sdo', '--show_diff_only', action='store_true',
                            help='Whether to only show the difference between the current directory '
                                 'and the diff directory.')
        parser.add_argument('-sso', '--show_same_only', action='store_true',
                            help='Whether to only show the same entries between the current directory '
                                 'and the diff directory.')
        parsed_args = parser.parse_args()

        # validate flag combinations
        if (parsed_args.show_diff_only or parsed_args.show_same_only) and not parsed_args.diff_directory:
            raise Exception('show_diff_only and show_same_only flags are only applicable if diff_directory is set.')
        if parsed_args.show_diff_only and parsed_args.show_same_only:
            raise Exception('show_diff_only and show_same_only cannot both be set, please choose one to set.')

        # resolve list-valued and path-valued options
        if ivy.exists(parsed_args.spec_names):
            spec_names = _parse_list_arg(parsed_args.spec_names, ', ')
        else:
            spec_names = None
        if ivy.exists(parsed_args.sub_directory):
            sub_dir = os.path.normpath(os.path.join(base_dir, parsed_args.sub_directory))
        else:
            sub_dir = base_dir
        if ivy.exists(parsed_args.keys_to_ignore):
            keys_to_ignore = _parse_list_arg(parsed_args.keys_to_ignore, ', ')
        else:
            keys_to_ignore = list()
        if ivy.exists(parsed_args.keychains_to_ignore):
            # NOTE(review): key-chains are split on ',' whereas keys and spec names are split on ', ' - confirm
            # whether this difference is intended
            keychains_to_ignore = _parse_list_arg(parsed_args.keychains_to_ignore, ',')
        else:
            keychains_to_ignore = list()

        these_json_args = get_json_args(
            sub_dir, keys_to_ignore, keychains_to_ignore, parsed_args.keychain_to_show, parsed_args.show_defaults,
            store_duplicates=True, current_dir_only=parsed_args.current_dir_only, spec_names=spec_names)

        if ivy.exists(parsed_args.diff_directory):
            other_sub_dir = os.path.normpath(os.path.join(base_dir, parsed_args.diff_directory))
            if other_sub_dir == sub_dir:
                raise Exception('Invalid diff_directory {} selected, it is the same as the sub_directory {}.'.format(
                    other_sub_dir, sub_dir))
            other_json_args = get_json_args(
                other_sub_dir, keys_to_ignore, keychains_to_ignore, parsed_args.keychain_to_show,
                parsed_args.show_defaults, store_duplicates=True, current_dir_only=parsed_args.current_dir_only,
                spec_names=spec_names)

            # label the two sides of the diff with the first path component at which the directories differ
            diff_keys = 'diff'
            for sub_folder, other_sub_folder in zip(sub_dir.split('/'), other_sub_dir.split('/')):
                if sub_folder != other_sub_folder:
                    diff_keys = [sub_folder, other_sub_folder]
                    break

            if parsed_args.show_diff_only:
                mode = 'diff_only'
            elif parsed_args.show_same_only:
                mode = 'same_only'
            else:
                mode = 'all'
            diff_json_args = ivy.Container.diff(these_json_args, other_json_args, mode=mode, diff_keys=diff_keys)
            keyword_color_dict = {'duplicated': 'magenta'}
            if isinstance(diff_keys, list):
                diff_keys_dict = dict(zip(diff_keys, ['red'] * 2))
                keyword_color_dict = {**keyword_color_dict, **diff_keys_dict}
            print(ivy.Container(diff_json_args, keyword_color_dict=keyword_color_dict))
        else:
            print(ivy.Container(these_json_args, keyword_color_dict={'duplicated': 'magenta'}))
    finally:
        # always pop the numpy framework pushed above, also when one of the checks raises
        ivy.unset_framework()
Beispiel #6
0
 def cleanup(self):
     """
     Close the trainer and unset the current ivy framework.

     The framework is unset even if closing the trainer raises.
     """
     try:
         self._trainer.close()
     finally:
         ivy.unset_framework()
Beispiel #7
0
    def tune(self):
        """
        Run a ray tune search over the trainer configuration held in this tuner's spec.

        A tune.Trainable subclass is defined locally, closing over the classes, args and specs stored on this
        instance; each trial builds its own trainer through the builder. The CPU and GPU resources per trial are
        derived from the device counts, the spec's parallel_trials and its device_priority, and the search is then
        run with tune.run under an AsyncHyperBandScheduler which minimises the reported 'cost'.

        :return: The value returned by tune.run.
        :raises Exception: If device_priority is neither 'cpu' nor 'gpu' while at least one GPU is reported.
        """

        # Create Trainable class #
        # -----------------------#

        # builder for TuneTrainable
        builder = self._builder

        # classes and args for TuneTrainable
        # (copied into locals so that the nested class closes over them rather than over this tuner instance)
        data_loader_class = self._data_loader_class
        network_class = self._network_class
        trainer_class = self._trainer_class
        dataset_dirs_args = self._dataset_dirs_args
        dataset_dirs_class = self._dataset_dirs_class
        dataset_dirs = self._dataset_dirs
        dataset_spec_args = self._dataset_spec_args
        dataset_spec_class = self._dataset_spec_class
        dataset_spec = self._dataset_spec
        data_loader_spec_args = self._data_loader_spec_args
        data_loader_spec_class = self._data_loader_spec_class
        data_loader_spec = self._data_loader_spec
        data_loader = self._data_loader
        network_spec_args = self._network_spec_args
        network_spec_class = self._network_spec_class
        network_spec = self._network_spec
        network = self._network
        trainer_spec_args = self._trainer_spec_args
        trainer_spec_class = self._trainer_spec_class
        trainer_spec = self._trainer_spec
        json_spec_path = self._json_spec_path
        spec_cont = self._spec_cont
        orig_log_dir = self._spec.trainer.spec.log_dir

        # noinspection PyAttributeOutsideInit
        class TuneTrainable(tune.Trainable):
            """Trainable which builds a trainer from the trial config and trains it in fixed-size chunks."""

            def setup(self, _):
                """Set the trial's framework, build and set up the trainer, and initialise the step counters."""
                # pushed here, popped again in cleanup()
                ivy.set_framework(self.config['framework'])
                self._train_steps_per_tune_step = self.config[
                    'train_steps_per_tune_step']
                config_cont = Container(self.config)
                # string identifying this trial: one '<short spec key>_<last key-chain entry>_<value>' part per
                # float/int/bool/None config entry whose key-chain is not in FIXED_CONFIG_KEYS
                self._config_str = '_'.join([
                    str(SHORT_SPEC_KEYS_DICT[kc.split('/')[0]]) + '_' +
                    kc.split('/')[-1] + '_' +
                    ("%.2g" % val if isinstance(val, float) else str(val))
                    for kc, val in config_cont.to_iterator()
                    if (isinstance(val, (float, int, bool, type(None)))
                        and kc not in FIXED_CONFIG_KEYS)
                ])
                # log into a sub-directory named after the trial config;
                # note this writes into the trainer_spec_args dict captured from the enclosing tuner
                trainer_spec_args['log_dir'] = os.path.join(
                    orig_log_dir, self._config_str)
                # per spec key, overlay the trial's config entries (if any) on top of the base args
                # NOTE(review): prune_key_from_key_chains(containing='_AND_') presumably strips grouping keys
                # which contain '_AND_' - confirm against the Container implementation
                new_args = dict()
                for class_key, args in zip(SPEC_KEYS, [
                        dataset_dirs_args, dataset_spec_args,
                        data_loader_spec_args, network_spec_args,
                        trainer_spec_args
                ]):
                    new_args[class_key] =\
                        Container({**args, **(self.config[class_key] if
                                              class_key in self.config else {})}).prune_key_from_key_chains(
                            containing='_AND_')

                self._trainer = builder.build_trainer(
                    data_loader_class=data_loader_class,
                    network_class=network_class,
                    trainer_class=trainer_class,
                    dataset_dirs_args=new_args['dataset_dirs'],
                    dataset_dirs_class=dataset_dirs_class,
                    dataset_dirs=dataset_dirs,
                    dataset_spec_args=new_args['dataset_spec'],
                    dataset_spec_class=dataset_spec_class,
                    dataset_spec=dataset_spec,
                    data_loader_spec_args=new_args['data_loader_spec'],
                    data_loader_spec_class=data_loader_spec_class,
                    data_loader_spec=data_loader_spec,
                    data_loader=data_loader,
                    network_spec_args=new_args['network_spec'],
                    network_spec_class=network_spec_class,
                    network_spec=network_spec,
                    network=network,
                    trainer_spec_args=new_args['trainer_spec'],
                    trainer_spec_class=trainer_spec_class,
                    trainer_spec=trainer_spec,
                    json_spec_path=json_spec_path,
                    spec_cont=spec_cont)
                # unset at_end configs
                # (the original values are remembered so that step() can act on them once the final
                # iteration has been reached)
                self._save_at_end = self._trainer.spec.save_at_end
                self._trainer.spec.save_at_end = False
                self._log_at_end = self._trainer.spec.log_at_end
                self._trainer.spec.log_at_end = False
                self._vis_at_end = self._trainer.spec.vis_at_end
                self._trainer.spec.vis_at_end = False

                self._trainer.setup()
                # noinspection PyProtectedMember
                self._trainer_global_step = self._trainer._starting_iteration
                self._trainer_total_iterations = self._trainer.spec.total_iterations
                # number of whole tune steps already covered by the trainer's starting iteration
                self.timestep = int(
                    math.floor(self._trainer_global_step /
                               self._train_steps_per_tune_step))

            # noinspection PyProtectedMember
            def step(self):
                """
                Train for up to train_steps_per_tune_step further iterations and report the result.

                :return: Dict with the 'timestep' and the 'cost' (moving average loss as numpy), plus the tune
                         DONE flag once the trainer's total iterations have been reached.
                """
                # never train beyond the trainer's total number of iterations
                total_iterations = min(
                    self._trainer_global_step +
                    self._train_steps_per_tune_step,
                    self._trainer_total_iterations)
                self._trainer_global_step = self._trainer.train(
                    self._trainer_global_step, total_iterations)
                self.timestep += 1
                ret_dict = {
                    'timestep': self.timestep,
                    'cost': ivy.to_numpy(self._trainer.moving_average_loss)
                }
                # final step: perform the at_end actions which were disabled on the trainer spec in setup()
                if self._trainer_global_step >= self._trainer_total_iterations:
                    if self._save_at_end:
                        self._trainer._save()
                    if self._log_at_end and ivy.exists(
                            self._trainer._training_batch):
                        self._trainer._log_scalars()
                    if self._vis_at_end:
                        dl = self._trainer.spec.data_loader
                        net = self._trainer.spec.network
                        tb = self._trainer._training_batch
                        gs = self._trainer._global_step
                        self._trainer._write_image_summaries(dl, net, tb, gs)
                    ret_dict[tune.result.DONE] = True
                return ret_dict

            def save_checkpoint(self, checkpoint_dir):
                """Save the trainer under checkpoint_dir as 'step_<timestep>' and return the save path."""
                os.makedirs(checkpoint_dir, exist_ok=True)
                save_name = 'step_{}'.format(self.timestep)
                save_path = os.path.join(checkpoint_dir, save_name)
                self._trainer.save(save_path)
                print('saved checkpoint to path: {}'.format(save_path))
                return save_path

            def load_checkpoint(self, checkpoint_path):
                """Restore the trainer from checkpoint_path, passing the current trainer global step."""
                self._trainer.restore(checkpoint_path,
                                      self._trainer_global_step)
                print('loaded checkpoint from {}'.format(checkpoint_path))

            def cleanup(self):
                """Close the trainer and pop the framework which was set in setup()."""
                self._trainer.close()
                ivy.unset_framework()

        # Run this trainable class #
        # -------------------------#

        # maximum number of tune steps needed to cover all trainer iterations
        max_t = int(
            np.ceil(self._spec.trainer.spec.total_iterations /
                    self._spec.train_steps_per_tune_step))
        # a grace_period of -1 in the spec is replaced by max_t
        ahb = AsyncHyperBandScheduler(
            time_attr="timestep",
            metric="cost",
            mode="min",
            grace_period=max_t
            if self._spec.grace_period == -1 else self._spec.grace_period,
            max_t=max_t)

        # split the available devices between the requested number of parallel trials
        num_cpus = multiprocessing.cpu_count()
        assert num_cpus > 0
        num_gpus = ivy.num_gpus()
        cpus_per_trial = num_cpus / self._spec.parallel_trials
        gpus_per_trial = num_gpus / self._spec.parallel_trials
        # the prioritised device's per-trial share is rounded to a whole number when above one; the number of
        # parallel trials which this allows then determines the per-trial share of the other device
        # (the cpu branch is also taken whenever no GPUs are reported)
        if self._spec.device_priority == 'cpu' or num_gpus == 0:
            cpus_per_trial = int(round(
                cpus_per_trial)) if cpus_per_trial > 1 else cpus_per_trial
            parallel_trials = math.floor(num_cpus / cpus_per_trial)
            gpus_per_trial = num_gpus / parallel_trials
            gpus_per_trial = math.floor(
                gpus_per_trial) if gpus_per_trial > 1 else gpus_per_trial
        elif self._spec.device_priority == 'gpu':
            gpus_per_trial = int(round(
                gpus_per_trial)) if gpus_per_trial > 1 else gpus_per_trial
            parallel_trials = math.floor(num_gpus / gpus_per_trial)
            cpus_per_trial = num_cpus / parallel_trials
            cpus_per_trial = math.floor(
                cpus_per_trial) if cpus_per_trial > 1 else cpus_per_trial
        else:
            raise Exception(
                'device_priority must be one of [ cpu | gpu ], but found {}'.
                format(self._spec.device_priority))
        # NOTE(review): no framework is set within this method, so this presumably pops one which was left set
        # elsewhere (e.g. while the spec was built) - confirm
        ivy.unset_framework()

        reporter = CLIReporter(['cost'])

        # initialize ray with custom temp_dir
        # (a 'ray' folder next to the trainer log_dir, i.e. inside the log_dir's parent directory)
        ray.init(_temp_dir=os.path.join(
            '/'.join(self._spec.trainer.spec.log_dir.split('/')[:-1]), 'ray'),
                 ignore_reinit_error=True)

        # the config passed to the trials holds only the spec entries which are dicts or tune sample
        # functions, plus the 'framework' and 'train_steps_per_tune_step' entries
        return tune.run(
            TuneTrainable,
            progress_reporter=reporter,
            name=self._spec.name,
            scheduler=ahb,
            stop={
                "timestep":
                int(
                    np.ceil(self._spec.trainer.spec.total_iterations /
                            self._spec.train_steps_per_tune_step))
            },
            num_samples=self._spec.num_samples,
            resources_per_trial={
                "cpu": cpus_per_trial,
                "gpu": gpus_per_trial
            },
            config={
                key: val
                for key, val in self._spec.items() if
                (isinstance(val, dict) or isinstance(val, tune.sample.Function)
                 or key in ['framework', 'train_steps_per_tune_step'])
            },
            local_dir='/'.join(
                self._spec.trainer.spec.log_dir.split('/')[:-1]),
            checkpoint_freq=self._spec.checkpoint_freq,
            checkpoint_at_end=True)
Beispiel #8
0
    def __init__(
            self,
            data_loader_class,
            network_class,
            trainer_class,
            dataset_dirs_args: dict = None,
            dataset_dirs_class: DatasetDirs.__base__ = DatasetDirs,
            dataset_dirs: DatasetDirs = None,
            dataset_spec_args: dict = None,
            dataset_spec_class: DatasetSpec.__base__ = DatasetSpec,
            dataset_spec: DatasetSpec = None,
            data_loader_spec_args: dict = None,
            data_loader_spec_class: DataLoaderSpec.__base__ = DataLoaderSpec,
            data_loader_spec: DataLoaderSpec = None,
            data_loader=None,
            network_spec_args: dict = None,
            network_spec_class: NetworkSpec.__base__ = NetworkSpec,
            network_spec: NetworkSpec = None,
            network=None,
            trainer_spec_args: dict = None,
            trainer_spec_class: TrainerSpec.__base__ = TrainerSpec,
            trainer_spec: TrainerSpec = None,
            trainer=None,
            tuner_spec_args: dict = None,
            tuner_spec_class: TunerSpec.__base__ = TunerSpec,
            tuner_spec: TunerSpec = None,
            json_spec_path: str = None,
            spec_cont: dict = None):
        """
        base class for any tune trainers

        Stores the given classes, args and specs, empties the ivy framework stack, and builds the tuner
        specification through the builder module. Every *_args dict which is None is replaced by an empty dict.

        :param data_loader_class: Data loader class, forwarded to the builder.
        :param network_class: Network class, forwarded to the builder.
        :param trainer_class: Trainer class, forwarded to the builder.
        :param json_spec_path: Path of a json specification, forwarded to the builder.
        :param spec_cont: Specification container, forwarded to the builder.
        :raises Exception: If ray[tune] is not available.

        The remaining *_args, *_class and pre-built spec/instance arguments are stored unchanged (apart from the
        None handling above) and, where noted in the code below, forwarded to the builder.
        """
        if not ivy.exists(tune):
            raise Exception(
                'ray[tune] is needed in order to use the Tuner class, but it is not installed. '
                'Please install via pip install ray[tune]')
        self._data_loader_class = data_loader_class
        self._network_class = network_class
        self._trainer_class = trainer_class
        self._dataset_dirs_args = ivy.default(dataset_dirs_args, dict())
        self._dataset_dirs_class = dataset_dirs_class
        self._dataset_dirs = dataset_dirs
        self._dataset_spec_args = ivy.default(dataset_spec_args, dict())
        self._dataset_spec_class = dataset_spec_class
        self._dataset_spec = dataset_spec
        self._data_loader_spec_args = ivy.default(data_loader_spec_args,
                                                  dict())
        self._data_loader_spec_class = data_loader_spec_class
        self._data_loader_spec = data_loader_spec
        self._data_loader = data_loader
        self._network_spec_args = ivy.default(network_spec_args, dict())
        self._network_spec_class = network_spec_class
        self._network_spec = network_spec
        self._network = network
        self._trainer_spec_args = ivy.default(trainer_spec_args, dict())
        self._trainer_spec_class = trainer_spec_class
        self._trainer_spec = trainer_spec
        self._trainer = trainer
        self._tuner_spec_args = ivy.default(tuner_spec_args, dict())
        self._tuner_spec_class = tuner_spec_class
        # NOTE(review): tuner_spec is stored but not forwarded to build_tuner_spec below - confirm intended
        self._tuner_spec = tuner_spec
        self._json_spec_path = json_spec_path
        self._spec_cont = spec_cont

        # initialized on _setup
        # NOTE(review): this overwrites the trainer argument stored above, so build_tuner_spec always
        # receives trainer=None - confirm that discarding a passed-in trainer is intended
        self._trainer = None

        # builder
        # the framework stack is emptied first, logging each framework which is removed
        while len(ivy.framework_stack) > 0:
            logging.info(
                'unsetting framework {}, framework stack must be empty when '
                'initializing tuner class.'.format(ivy.framework_stack[-1]))
            ivy.unset_framework()
        self._builder = builder_module

        # tuner spec
        self._spec = self._builder.build_tuner_spec(
            data_loader_class=self._data_loader_class,
            network_class=self._network_class,
            trainer_class=self._trainer_class,
            dataset_dirs_args=self._dataset_dirs_args,
            dataset_dirs_class=self._dataset_dirs_class,
            dataset_dirs=self._dataset_dirs,
            dataset_spec_args=self._dataset_spec_args,
            dataset_spec_class=self._dataset_spec_class,
            dataset_spec=self._dataset_spec,
            data_loader_spec_args=self._data_loader_spec_args,
            data_loader_spec_class=self._data_loader_spec_class,
            data_loader_spec=self._data_loader_spec,
            data_loader=self._data_loader,
            network_spec_args=self._network_spec_args,
            network_spec_class=self._network_spec_class,
            network_spec=self._network_spec,
            network=self._network,
            trainer_spec_args=self._trainer_spec_args,
            trainer_spec_class=self._trainer_spec_class,
            trainer_spec=self._trainer_spec,
            trainer=self._trainer,
            tuner_spec_args=self._tuner_spec_args,
            tuner_spec_class=self._tuner_spec_class,
            json_spec_path=self._json_spec_path,
            spec_cont=self._spec_cont)
        self._spec = _convert_tuner_spec(self._spec)
Beispiel #9
0
def _run_torch_lstm_demo():
    """Train a dummy PyTorch model which wraps the dual stacked Ivy LSTM memory module."""

    # using the Ivy LSTM memory module, dual stacked, in a PyTorch model

    class TorchModelWithLSTM(torch.nn.Module):
        def __init__(self, channels_in, channels_out):
            torch.nn.Module.__init__(self)
            self._linear = torch.nn.Linear(channels_in, 64)
            self._lstm = ivy_mem.LSTM(64, channels_out, 2, return_state=False)
            self._assign_variables()

        def _assign_variables(self):
            # register every LSTM variable as a torch parameter, then point the LSTM at those parameters
            self._lstm.v.map(lambda x, kc: self.register_parameter(
                name=kc, param=torch.nn.Parameter(x)))
            self._lstm.v = self._lstm.v.map(lambda x, kc: self._parameters[kc])

        def forward(self, x):
            x = self._linear(x)
            return self._lstm(x)

    in_channels = 32
    out_channels = 8
    ivy.set_framework('torch')
    try:
        # create model
        model = TorchModelWithLSTM(in_channels, out_channels)

        # define inputs
        batch_shape = [1, 2]
        timesteps = 3
        input_shape = batch_shape + [timesteps, in_channels]
        input_seq = torch.rand(input_shape)

        # call model and test output
        output_seq = model(input_seq)
        assert input_seq.shape[:-1] == output_seq.shape[:-1]
        assert input_seq.shape[-1] == in_channels
        assert output_seq.shape[-1] == out_channels

        # define loss function
        target = torch.zeros_like(output_seq)

        def loss_fn():
            pred = model(input_seq)
            return torch.sum((pred - target)**2)

        # define optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

        # train model
        print('\ntraining dummy PyTorch LSTM model...\n')
        for i in range(10):
            # clear the gradients of the previous step, torch would otherwise accumulate them
            optimizer.zero_grad()
            loss = loss_fn()
            loss.backward()
            optimizer.step()
            print('step {}, loss = {}'.format(i, loss))
        print('\ndummy PyTorch LSTM model trained!\n')
    finally:
        # always pop the framework, also when an assertion or training step fails
        ivy.unset_framework()


def _run_tf_ntm_demo():
    """Train a dummy TensorFlow keras model which wraps the Ivy NTM memory module."""

    # using the Ivy NTM memory module in a TensorFlow model

    class TfModelWithNTM(tf.keras.Model):
        def __init__(self, channels_in, channels_out):
            tf.keras.Model.__init__(self)
            self._linear = tf.keras.layers.Dense(64)
            memory_size = 4
            memory_vector_dim = 1
            self._ntm = ivy_mem.NTM(input_dim=64,
                                    output_dim=channels_out,
                                    ctrl_output_size=channels_out,
                                    ctrl_layers=1,
                                    memory_size=memory_size,
                                    memory_vector_dim=memory_vector_dim,
                                    read_head_num=1,
                                    write_head_num=1)
            self._assign_variables()

        def _assign_variables(self):
            # add one keras weight per NTM variable, copy the NTM values in, then point the NTM at the weights
            self._ntm.v.map(
                lambda x, kc: self.add_weight(name=kc, shape=x.shape))
            self.set_weights(
                [ivy.to_numpy(v) for k, v in self._ntm.v.to_iterator()])
            self.trainable_weights_dict = dict()
            for weight in self.trainable_weights:
                self.trainable_weights_dict[weight.name] = weight
            self._ntm.v = self._ntm.v.map(
                lambda x, kc: self.trainable_weights_dict[kc + ':0'])

        def call(self, x, **kwargs):
            x = self._linear(x)
            return self._ntm(x)

    in_channels = 32
    out_channels = 8
    ivy.set_framework('tensorflow')
    try:
        # create model
        model = TfModelWithNTM(in_channels, out_channels)

        # define inputs
        batch_shape = [1, 2]
        timesteps = 3
        input_shape = batch_shape + [timesteps, in_channels]
        input_seq = tf.random.uniform(input_shape)

        # call model and test output
        output_seq = model(input_seq)
        assert input_seq.shape[:-1] == output_seq.shape[:-1]
        assert input_seq.shape[-1] == in_channels
        assert output_seq.shape[-1] == out_channels

        # define loss function
        target = tf.zeros_like(output_seq)

        def loss_fn():
            pred = model(input_seq)
            return tf.reduce_sum((pred - target)**2)

        # define optimizer
        optimizer = tf.keras.optimizers.Adam(1e-2)

        # train model
        print('\ntraining dummy TensorFlow NTM model...\n')
        for i in range(10):
            with tf.GradientTape() as tape:
                loss = loss_fn()
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            print('step {}, loss = {}'.format(i, loss))
        print('\ndummy TensorFlow NTM model trained!\n')
    finally:
        # always pop the framework, also when an assertion or training step fails
        ivy.unset_framework()


def _run_ivy_esm_demo():
    """Train a dummy pure-Ivy model which wraps the Ivy ESM memory module."""

    # using the Ivy ESM memory module in a pure-Ivy model; the backend selected below is 'torch'
    # NOTE(review): this demo has been described as using a JAX backend - confirm the intended backend
    # ToDo: add pre-ESM conv layers to this demo

    class IvyModelWithESM(ivy.Module):
        def __init__(self, channels_in, channels_out):
            self._channels_in = channels_in
            self._esm = ivy_mem.ESM(omni_image_dims=(16, 32))
            self._linear = ivy_mem.Linear(channels_in, channels_out)
            ivy.Module.__init__(self, 'cpu')

        def _forward(self, obs):
            mem = self._esm(obs)
            x = ivy.reshape(mem.mean, (-1, self._channels_in))
            return self._linear(x)

    in_channels = 32
    out_channels = 8
    ivy.set_framework('torch')
    try:
        # create model
        model = IvyModelWithESM(in_channels, out_channels)

        # input config
        batch_size = 1
        image_dims = [5, 5]
        num_timesteps = 2
        num_feature_channels = 3

        # create image of pixel co-ordinates
        uniform_pixel_coords =\
            ivy_vision.create_uniform_pixel_coords_image(image_dims, [batch_size, num_timesteps])

        # define camera measurement
        depths = ivy.random_uniform(shape=[batch_size, num_timesteps] +
                                    image_dims + [1])
        ds_pixel_coords = ivy_vision.depth_to_ds_pixel_coords(depths)
        inv_calib_mats = ivy.random_uniform(
            shape=[batch_size, num_timesteps, 3, 3])
        cam_coords = ivy_vision.ds_pixel_to_cam_coords(ds_pixel_coords,
                                                       inv_calib_mats)[..., 0:3]
        features = ivy.random_uniform(shape=[batch_size, num_timesteps] +
                                      image_dims + [num_feature_channels])
        img_mean = ivy.concatenate((cam_coords, features), -1)
        cam_rel_mat = ivy.identity(4, batch_shape=[batch_size,
                                                   num_timesteps])[..., 0:3, :]

        # place these into an ESM camera measurement container
        esm_cam_meas = ESMCamMeasurement(img_mean=img_mean,
                                         cam_rel_mat=cam_rel_mat)

        # define agent pose transformation
        agent_rel_mat = ivy.identity(4, batch_shape=[batch_size,
                                                     num_timesteps])[..., 0:3, :]

        # collect together into an ESM observation container
        esm_obs = ESMObservation(img_meas={'camera_0': esm_cam_meas},
                                 agent_rel_mat=agent_rel_mat)

        # call model and test output
        output = model(esm_obs)
        assert output.shape[-1] == out_channels

        # define loss function
        target = ivy.zeros_like(output)

        def loss_fn(v):
            pred = model(esm_obs, v=v)
            return ivy.reduce_mean((pred - target)**2)

        # optimizer
        optimizer = ivy.SGD(lr=1e-4)

        # train model
        print('\ntraining dummy Ivy ESM model...\n')
        for i in range(10):
            loss, grads = ivy.execute_with_gradients(loss_fn, model.v)
            model.v = optimizer.step(model.v, grads)
            print('step {}, loss = {}'.format(i, ivy.to_numpy(loss).item()))
        print('\ndummy Ivy ESM model trained!\n')
    finally:
        # always pop the framework, also when an assertion or training step fails
        ivy.unset_framework()


def main():
    """
    Run the three memory module demos in sequence.

    Each demo sets its own framework, builds a small model around one Ivy memory module, checks the output
    shape and trains for ten steps on a dummy loss, unsetting the framework again when it finishes or fails.
    """

    # LSTM #
    # -----#
    _run_torch_lstm_demo()

    # NTM #
    # ----#
    _run_tf_ntm_demo()

    # ESM #
    # ----#
    _run_ivy_esm_demo()

    # message
    print('End of Run Through Demo!')