Example #1
    def test_step(self, inputs, beta):
        inputs = self.make_sequences_variable_length(inputs)  #
        actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs[
            'masks']

        if self.args.gcbc:
            policy = self.step(inputs)
            loss = self.compute_loss(actions, policy, mask, seq_lens)
            log_action_breakdown(policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.valid_position_loss, self.valid_max_position_loss, \
                                 self.valid_rotation_loss, self.valid_max_rotation_loss, self.valid_gripper_loss, self.compute_MAE)
        else:
            enc_policy, plan_policy, encoding, plan = self.step(inputs)
            act_enc_loss = record(
                self.compute_loss(actions, enc_policy, mask, seq_lens),
                self.metrics['valid_act_with_enc_loss'])

            if self.args.discrete:
                loss = act_enc_loss
                log_action_breakdown(enc_policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.metrics['valid_position_loss'], \
                                 self.metrics['valid_max_position_loss'], self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'], self.metrics['valid_gripper_loss'], self.compute_MAE)
            else:
                act_plan_loss = record(
                    self.compute_loss(actions, plan_policy, mask, seq_lens),
                    self.metrics['valid_act_with_plan_loss'])
                reg_loss = record(
                    self.compute_regularisation_loss(plan, encoding),
                    self.metrics['valid_reg_loss'])
                loss = act_plan_loss + reg_loss * beta
                log_action_breakdown(plan_policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.metrics['valid_position_loss'], \
                                 self.metrics['valid_max_position_loss'], self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'], self.metrics['valid_gripper_loss'], self.compute_MAE)
        return record(loss, self.metrics['valid_loss'])
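The record helper used throughout these examples is not shown. From its call sites it appears to update a Keras metric and return its first argument unchanged; a minimal sketch under that assumption (treating each self.metrics[...] entry as a tf.keras.metrics.Mean):

def record(value, metric):
    # Plausible reconstruction, not the author's code: update the running
    # Keras metric (e.g. a tf.keras.metrics.Mean) and pass the value through
    # so it can be used inline, e.g. loss = record(..., self.metrics['valid_loss']).
    metric(value)
    return value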
Example #2
    def test_step(self, inputs, beta):
        states, actions, goals, seq_lens, mask = inputs['obs'], inputs[
            'acts'], inputs['goals'], inputs['seq_lens'], inputs['masks']
        ########################### Between here
        if self.args.images:
            imgs, proprioceptive_features, goal_imgs = inputs['imgs'], inputs[
                'proprioceptive_features'], inputs['goal_imgs']
            B, T, H, W, C = imgs.shape
            imgs, goal_imgs = tf.reshape(imgs, [B * T, H, W, C]), tf.reshape(
                goal_imgs, [B * T, H, W, C])
            img_embeddings, goal_embeddings = tf.reshape(
                self.cnn(imgs),
                [B, T, -1]), tf.reshape(self.cnn(goal_imgs), [B, T, -1])

            states = tf.concat(
                [img_embeddings, proprioceptive_features], -1
            )  # concatenate the image embedding with the robot's own xyz, orientation and angle (proprioceptive pose)
            goals = goal_embeddings  # shape (B, T, embed_size)

        if self.args.gcbc:
            policy = self.actor([states, goals], training=False)
            loss = self.compute_loss(actions, policy, mask, seq_lens)
            log_action_breakdown(policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.valid_position_loss, self.valid_max_position_loss, \
                                 self.valid_rotation_loss, self.valid_max_rotation_loss, self.valid_gripper_loss, self.compute_MAE)
        else:
            encoding = self.encoder([states, actions])
            plan = self.planner(
                [states[:, 0, :], goals[:, 0, :]]
            )  # the final goals are tiled out over the entire non-masked sequence, so the first timestep already holds the final goal
            z_enc = encoding.sample()
            z_plan = plan.sample()
            z_enc_tiled = tf.tile(tf.expand_dims(z_enc, 1),
                                  (1, self.dl.window_size, 1))
            z_plan_tiled = tf.tile(tf.expand_dims(z_plan, 1),
                                   (1, self.dl.window_size, 1))
            enc_policy = self.actor([states, z_enc_tiled, goals])
            plan_policy = self.actor([states, z_plan_tiled, goals])
            ############### and here could be abstracted into one function (a sketch follows this example)
            act_enc_loss = record(
                self.compute_loss(actions, enc_policy, mask, seq_lens),
                self.metrics['valid_act_with_enc_loss'])
            act_plan_loss = record(
                self.compute_loss(actions, plan_policy, mask, seq_lens),
                self.metrics['valid_act_with_plan_loss'])
            reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                              self.metrics['valid_reg_loss'])
            loss = act_plan_loss + reg_loss * beta
            log_action_breakdown(plan_policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.metrics['valid_position_loss'], \
                                 self.metrics['valid_max_position_loss'], self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'], self.metrics['valid_gripper_loss'], self.compute_MAE)
        if self.args.gcbc:
            return record(loss, self.metrics['valid_loss'])
        else:
            return record(loss, self.metrics['valid_loss']), z_enc, z_plan
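As the marker comments in Example #2 note, the span between "Between here" and "and here" could be pulled into one helper. A possible sketch, assuming the same class attributes (encoder, planner, actor, dl.window_size); the name encode_and_plan is illustrative, not part of the original codebase:

    def encode_and_plan(self, states, actions, goals):
        # Encode the full sequence, plan from the first state/goal pair, tile the
        # sampled latents over the window, and run the actor under both latents.
        encoding = self.encoder([states, actions])
        plan = self.planner([states[:, 0, :], goals[:, 0, :]])
        z_enc, z_plan = encoding.sample(), plan.sample()
        z_enc_tiled = tf.tile(tf.expand_dims(z_enc, 1), (1, self.dl.window_size, 1))
        z_plan_tiled = tf.tile(tf.expand_dims(z_plan, 1), (1, self.dl.window_size, 1))
        enc_policy = self.actor([states, z_enc_tiled, goals])
        plan_policy = self.actor([states, z_plan_tiled, goals])
        return enc_policy, plan_policy, encoding, plan, z_enc, z_plan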
Example #3
    def test_step(self, **kwargs):
        inputs, beta, lang_labelled_inputs, external_videos = kwargs[
            'batch'], kwargs['beta'], kwargs['lang'], kwargs['video']

        inputs = self.make_sequences_variable_length(inputs)  #
        actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs[
            'masks']

        if self.args.gcbc:
            policy = self.step(inputs)
            loss = self.compute_loss(actions, policy, mask, seq_lens)
            log_action_breakdown(policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.valid_position_loss, self.valid_max_position_loss, \
                                 self.valid_rotation_loss, self.valid_max_rotation_loss, self.valid_gripper_loss, self.compute_MAE)
        else:
            enc_policy, plan_policy, encoding, plan, indices, actions, mask, seq_lens, sentence_embeddings = self.step(
                inputs, lang_labelled_inputs, external_videos)
            act_enc_loss = record(
                self.compute_loss(actions, enc_policy, mask, seq_lens),
                self.metrics['valid_act_with_enc_loss'])

            if self.args.discrete:
                planner_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.stop_gradient(tf.nn.softmax(encoding, -1)),
                    logits=plan)
                record(planner_loss,
                       self.metrics['valid_discrete_planner_loss'])
                loss = act_enc_loss + planner_loss * beta
            else:
                act_plan_loss = record(
                    self.compute_loss(actions, plan_policy, mask, seq_lens),
                    self.metrics['valid_act_with_plan_loss'])
                reg_loss = record(
                    self.compute_regularisation_loss(plan, encoding),
                    self.metrics['valid_reg_loss'])
                loss = act_plan_loss + reg_loss * beta
            log_action_breakdown(plan_policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.metrics['valid_position_loss'], \
                                self.metrics['valid_max_position_loss'], self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'], self.metrics['valid_gripper_loss'], self.compute_MAE)
            log_action_breakdown(enc_policy, actions, mask, seq_lens, self.args.num_distribs is not None, self.dl.quaternion_act, self.metrics['valid_enc_position_loss'], \
                                self.metrics['valid_enc_max_position_loss'], self.metrics['valid_enc_rotation_loss'], self.metrics['valid_enc_max_rotation_loss'], self.metrics['valid_enc_gripper_loss'], self.compute_MAE)

            if self.args.use_language:
                # pass probabilistic=False and feed the distribution's .sample() in directly, because slicing the distribution appears to sample it automatically
                log_action_breakdown(plan_policy.sample()[indices['unlabelled']:], actions[indices['unlabelled']:], mask[indices['unlabelled']:], seq_lens[indices['unlabelled']:], False, self.dl.quaternion_act,
                self.metrics['valid_lang_position_loss'], self.metrics['valid_lang_max_position_loss'], self.metrics['valid_lang_rotation_loss'], self.metrics['valid_lang_max_rotation_loss'], \
                    self.metrics['valid_lang_gripper_loss'], self.compute_MAE)

        return record(loss, self.metrics['valid_loss'])
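compute_regularisation_loss is not shown in any of these examples. Since encoding and plan both expose .sample(), they are presumably distribution objects, and the regularisation term is plausibly a KL divergence between them; a minimal sketch under that assumption (tfp is tensorflow_probability; the KL direction and any free-bits clipping are guesses):

    def compute_regularisation_loss(self, plan, encoding):
        # Assumed implementation: KL between the sequence encoder's posterior
        # (encoding) and the planner's prior (plan), averaged over the batch.
        kl = tfp.distributions.kl_divergence(encoding, plan)
        return tf.reduce_mean(kl)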
Example #4
    def train_step(self, inputs, beta):
        inputs = self.make_sequences_variable_length(inputs)

        with tf.GradientTape() as actor_tape, tf.GradientTape(
        ) as encoder_tape, tf.GradientTape() as planner_tape:
            actions, seq_lens, mask = inputs['acts'], inputs[
                'seq_lens'], inputs['masks']

            if self.args.gcbc:
                policy = self.step(inputs)
                loss = self.compute_loss(actions, policy, mask, seq_lens)
                gradients = actor_tape.gradient(loss,
                                                self.actor.trainable_variables)
                self.actor_optimizer.apply_gradients(
                    zip(gradients, self.actor.trainable_variables))
            else:
                enc_policy, plan_policy, encoding, plan = self.step(inputs)
                act_enc_loss = record(
                    self.compute_loss(actions, enc_policy, mask, seq_lens),
                    self.metrics['train_act_with_enc_loss'])
                act_plan_loss = record(
                    self.compute_loss(actions, plan_policy, mask, seq_lens),
                    self.metrics['train_act_with_plan_loss'])
                reg_loss = record(
                    self.compute_regularisation_loss(plan, encoding),
                    self.metrics['train_reg_loss'])
                loss = act_enc_loss + reg_loss * beta

                if self.args.fp16:
                    actor_gradients = self.compute_fp16_grads(
                        self.actor_optimizer, loss, actor_tape, self.actor)
                    encoder_gradients = self.compute_fp16_grads(
                        self.encoder_optimizer, loss, encoder_tape,
                        self.encoder)
                    planner_gradients = self.compute_fp16_grads(
                        self.planner_optimizer, loss, planner_tape,
                        self.planner)
                else:
                    actor_gradients = actor_tape.gradient(
                        loss, self.actor.trainable_variables)
                    encoder_gradients = encoder_tape.gradient(
                        loss, self.encoder.trainable_variables)
                    planner_gradients = planner_tape.gradient(
                        loss, self.planner.trainable_variables)

                actor_norm = record(tf.linalg.global_norm(actor_gradients),
                                    self.metrics['actor_grad_norm'])
                encoder_norm = record(tf.linalg.global_norm(encoder_gradients),
                                      self.metrics['encoder_grad_norm'])
                planner_norm = record(tf.linalg.global_norm(planner_gradients),
                                      self.metrics['planner_grad_norm'])

                gradients = actor_gradients + encoder_gradients + planner_gradients
                record(tf.linalg.global_norm(gradients),
                       self.metrics['global_grad_norm'])

                self.actor_optimizer.apply_gradients(
                    zip(actor_gradients, self.actor.trainable_variables))
                self.encoder_optimizer.apply_gradients(
                    zip(encoder_gradients, self.encoder.trainable_variables))
                self.planner_optimizer.apply_gradients(
                    zip(planner_gradients, self.planner.trainable_variables))

        return record(loss, self.metrics['train_loss'])
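compute_fp16_grads is likewise external to these snippets. With tf.keras mixed precision, float16 gradients are normally taken through a LossScaleOptimizer; a sketch of what the helper might look like, assuming each optimizer is a tf.keras.mixed_precision.LossScaleOptimizer and that the call happens while the tape is still recording (as in Example #4):

    def compute_fp16_grads(self, optimizer, loss, tape, model):
        # Scale the loss, differentiate, then unscale the gradients so the
        # optimizer update sees gradients in the original (unscaled) range.
        scaled_loss = optimizer.get_scaled_loss(loss)
        scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
        return optimizer.get_unscaled_gradients(scaled_grads)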
Example #5
    def train_step(self, inputs, beta):
        with tf.GradientTape() as actor_tape, tf.GradientTape(
        ) as encoder_tape, tf.GradientTape() as planner_tape:
            # Todo: figure out mask and seq_lens for new dataset
            states, actions, goals, seq_lens, mask = inputs['obs'], inputs[
                'acts'], inputs['goals'], inputs['seq_lens'], inputs['masks']

            # Steps required for image inputs:
            # 1. When using images, change the definition of obs_dim to the feature-encoder dim + proprioceptive features.
            # 2. Reshape imgs to (B*T, H, W, C).
            # 3. Substitute the resulting embeddings in for states and goals.
            # 4. Then there should be no further changes!
            if self.args.images:
                imgs, proprioceptive_features, goal_imgs = inputs[
                    'imgs'], inputs['proprioceptive_features'], inputs[
                        'goal_imgs']
                B, T, H, W, C = imgs.shape
                imgs, goal_imgs = tf.reshape(imgs,
                                             [B * T, H, W, C]), tf.reshape(
                                                 goal_imgs, [B * T, H, W, C])
                img_embeddings, goal_embeddings = tf.reshape(
                    self.cnn(imgs),
                    [B, T, -1]), tf.reshape(self.cnn(goal_imgs), [B, T, -1])

                states = tf.concat(
                    [img_embeddings, proprioceptive_features], -1
                )  # concatenate the image embedding with the robot's own xyz, orientation and angle (proprioceptive pose)
                goals = goal_embeddings  # shape (B, T, embed_size)

            if self.args.gcbc:
                distrib = self.actor([states, goals])
                loss = self.compute_loss(actions, distrib, mask, seq_lens)
                gradients = actor_tape.gradient(loss,
                                                self.actor.trainable_variables)
                self.optimizer.apply_gradients(
                    zip(gradients, self.actor.trainable_variables))
            else:
                encoding = self.encoder([states, actions])
                plan = self.planner(
                    [states[:, 0, :], goals[:, 0, :]]
                )  # the final goals are tiled out over the entire non-masked sequence, so the first timestep already holds the final goal
                z_enc = encoding.sample()
                z_plan = plan.sample()
                z_enc_tiled = tf.tile(tf.expand_dims(z_enc, 1),
                                      (1, self.dl.window_size, 1))
                z_plan_tiled = tf.tile(tf.expand_dims(z_plan, 1),
                                       (1, self.dl.window_size, 1))

                enc_policy = self.actor([states, z_enc_tiled, goals])
                plan_policy = self.actor([states, z_plan_tiled, goals])

                act_enc_loss = record(
                    self.compute_loss(actions, enc_policy, mask, seq_lens),
                    self.metrics['train_act_with_enc_loss'])
                act_plan_loss = record(
                    self.compute_loss(actions, plan_policy, mask, seq_lens),
                    self.metrics['train_act_with_plan_loss'])
                reg_loss = record(
                    self.compute_regularisation_loss(plan, encoding),
                    self.metrics['train_reg_loss'])
                loss = act_enc_loss + reg_loss * beta

                actor_gradients = actor_tape.gradient(
                    loss, self.actor.trainable_variables)
                encoder_gradients = encoder_tape.gradient(
                    loss, self.encoder.trainable_variables)
                planner_gradients = planner_tape.gradient(
                    loss, self.planner.trainable_variables)

                actor_norm = record(tf.linalg.global_norm(actor_gradients),
                                    self.metrics['actor_grad_norm'])
                encoder_norm = record(tf.linalg.global_norm(encoder_gradients),
                                      self.metrics['encoder_grad_norm'])
                planner_norm = record(tf.linalg.global_norm(planner_gradients),
                                      self.metrics['planner_grad_norm'])

                gradients = actor_gradients + encoder_gradients + planner_gradients
                record(tf.linalg.global_norm(gradients),
                       self.metrics['global_grad_norm'])

                self.actor_optimizer.apply_gradients(
                    zip(actor_gradients, self.actor.trainable_variables))
                self.encoder_optimizer.apply_gradients(
                    zip(encoder_gradients, self.encoder.trainable_variables))
                self.planner_optimizer.apply_gradients(
                    zip(planner_gradients, self.planner.trainable_variables))

        return record(loss, self.metrics['train_loss'])
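Examples #2 and #5 fold the time axis into the batch axis so that a single CNN call embeds every frame, then unfold back to (B, T, embed_size). A small self-contained check of that reshape trick (the toy CNN and shapes are illustrative only):

import tensorflow as tf

B, T, H, W, C = 2, 4, 8, 8, 3
imgs = tf.random.uniform([B, T, H, W, C])
cnn = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
])
flat = tf.reshape(imgs, [B * T, H, W, C])       # fold time into batch: (8, 8, 8, 3)
embeddings = tf.reshape(cnn(flat), [B, T, -1])  # unfold back: (2, 4, embed_size)
print(embeddings.shape)                         # (2, 4, 4)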
Example #6
    def train_step(self, **kwargs):
        inputs, beta, lang_labelled_inputs, external_videos, bulk = kwargs[
            'batch'], kwargs['beta'], kwargs['lang'], kwargs['video'], kwargs[
                'bulk']

        if self.args.bulk_split > 0:
            inputs = {
                k: tf.concat([inputs[k], bulk[k]], axis=0)
                for k in inputs.keys()
            }  # combine them

        inputs = self.make_sequences_variable_length(inputs)

        with tf.GradientTape() as actor_tape, tf.GradientTape() as encoder_tape, tf.GradientTape() as planner_tape, tf.GradientTape() as cnn_tape, tf.GradientTape() as gripper_cnn_tape,\
                                tf.GradientTape() as img_goal_embed_tape, tf.GradientTape() as lang_goal_embed_tape:

            if self.args.gcbc:
                actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs['masks']
                policy = self.step(inputs)
                loss = self.compute_loss(actions, policy, mask, seq_lens)
                gradients = actor_tape.gradient(loss,
                                                self.actor.trainable_variables)
                self.actor_optimizer.apply_gradients(
                    zip(gradients, self.actor.trainable_variables))
            else:
                enc_policy, plan_policy, encoding, plan, indices, actions, mask, seq_lens, sentence_embeddings = self.step(
                    inputs, lang_labelled_inputs, external_videos)

                act_enc_loss = record(
                    self.compute_loss(actions, enc_policy, mask, seq_lens),
                    self.metrics['train_act_with_enc_loss'])
                if self.args.discrete:
                    planner_loss = tf.nn.softmax_cross_entropy_with_logits(
                        labels=tf.stop_gradient(tf.nn.softmax(encoding, -1)),
                        logits=plan)
                    record(planner_loss,
                           self.metrics['train_discrete_planner_loss'])
                    loss = act_enc_loss + planner_loss * beta

                else:
                    act_plan_loss = record(
                        self.compute_loss(actions, plan_policy, mask,
                                          seq_lens),
                        self.metrics['train_act_with_plan_loss'])
                    reg_loss = record(
                        self.compute_regularisation_loss(plan, encoding),
                        self.metrics['train_reg_loss'])
                    loss = act_enc_loss + reg_loss * beta

                if self.args.fp16:
                    actor_gradients = self.compute_fp16_grads(
                        self.actor_optimizer, loss, actor_tape, self.actor)
                    encoder_gradients = self.compute_fp16_grads(
                        self.encoder_optimizer, loss, encoder_tape,
                        self.encoder)
                    planner_gradients = self.compute_fp16_grads(
                        self.planner_optimizer, loss, planner_tape,
                        self.planner)
                    if self.args.images:
                        cnn_gradients = self.compute_fp16_grads(
                            self.cnn_optimizer, loss, cnn_tape, self.cnn)
                        goal_to_goal_space_grads = self.compute_fp16_grads(
                            self.img_embed_to_goal_space_optimizer, loss,
                            img_goal_embed_tape, self.img_embed_to_goal_space)
                    if self.args.gripper_images:
                        gripper_cnn_gradients = self.compute_fp16_grads(
                            self.gripper_cnn_optimizer, loss, gripper_cnn_tape,
                            self.gripper_cnn)
                    if self.args.use_language: raise NotImplementedError
                else:
                    actor_gradients = actor_tape.gradient(
                        loss, self.actor.trainable_variables)
                    encoder_gradients = encoder_tape.gradient(
                        loss, self.encoder.trainable_variables)
                    planner_gradients = planner_tape.gradient(
                        loss, self.planner.trainable_variables)
                    if self.args.images:
                        cnn_gradients = cnn_tape.gradient(
                            loss, self.cnn.trainable_variables)
                        img_goal_to_goal_space_grads = img_goal_embed_tape.gradient(
                            loss,
                            self.img_embed_to_goal_space.trainable_variables)
                    if self.args.gripper_images:
                        gripper_cnn_gradients = gripper_cnn_tape.gradient(
                            loss, self.gripper_cnn.trainable_variables)
                    if self.args.use_language:
                        lang_goal_to_goal_space_grads = lang_goal_embed_tape.gradient(
                            loss,
                            self.lang_embed_to_goal_space.trainable_variables)

                #################### Calc individual norms
                actor_norm = record(tf.linalg.global_norm(actor_gradients),
                                    self.metrics['actor_grad_norm'])
                encoder_norm = record(tf.linalg.global_norm(encoder_gradients),
                                      self.metrics['encoder_grad_norm'])
                planner_norm = record(tf.linalg.global_norm(planner_gradients),
                                      self.metrics['planner_grad_norm'])
                if self.args.images:
                    cnn_norm = record(tf.linalg.global_norm(cnn_gradients),
                                      self.metrics['cnn_grad_norm'])
                    img_goal_to_goal_space_norm = record(
                        tf.linalg.global_norm(img_goal_to_goal_space_grads),
                        self.metrics['img_embed_to_goal_space_norm'])
                if self.args.gripper_images:
                    gripper_cnn_norm = record(
                        tf.linalg.global_norm(gripper_cnn_gradients),
                        self.metrics['gripper_cnn_grad_norm'])
                if self.args.use_language:
                    lang_goal_to_goal_space_norm = record(
                        tf.linalg.global_norm(lang_goal_to_goal_space_grads),
                        self.metrics['lang_embed_to_goal_space_norm'])

                ##################### Calc global grad norm
                gradients = actor_gradients + encoder_gradients + planner_gradients
                if self.args.images:
                    gradients = gradients + cnn_gradients + img_goal_to_goal_space_grads
                if self.args.gripper_images: gradients += gripper_cnn_gradients
                if self.args.use_language:
                    gradients += lang_goal_to_goal_space_grads
                record(tf.linalg.global_norm(gradients),
                       self.metrics['global_grad_norm'])

                #################### Apply optimizer updates
                self.actor_optimizer.apply_gradients(
                    zip(actor_gradients, self.actor.trainable_variables))
                self.encoder_optimizer.apply_gradients(
                    zip(encoder_gradients, self.encoder.trainable_variables))
                if not self.args.discrete:
                    self.planner_optimizer.apply_gradients(
                        zip(planner_gradients, self.planner.trainable_variables
                            ))  # TODO TRAIN AS SECOND STAGE
                if self.args.images:
                    self.cnn_optimizer.apply_gradients(
                        zip(cnn_gradients, self.cnn.trainable_variables))
                    self.img_embed_to_goal_space_optimizer.apply_gradients(
                        zip(img_goal_to_goal_space_grads,
                            self.img_embed_to_goal_space.trainable_variables))
                if self.args.gripper_images:
                    self.gripper_cnn_optimizer.apply_gradients(
                        zip(gripper_cnn_gradients,
                            self.gripper_cnn.trainable_variables))
                if self.args.use_language:
                    self.lang_embed_to_goal_space_optimizer.apply_gradients(
                        zip(lang_goal_to_goal_space_grads,
                            self.lang_embed_to_goal_space.trainable_variables))
                ################### Fin

        return record(loss, self.metrics['train_loss'])
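The final train_step takes its inputs as keyword arguments; a hypothetical call site, with illustrative batch names matching the kwargs unpacked at the top of the method:

loss = model.train_step(batch=play_batch,
                        beta=beta,
                        lang=lang_labelled_batch,
                        video=external_video_batch,
                        bulk=bulk_batch)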