def test_step(self, inputs, beta):
    inputs = self.make_sequences_variable_length(inputs)
    actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs['masks']
    if self.args.gcbc:
        policy = self.step(inputs)
        loss = self.compute_loss(actions, policy, mask, seq_lens)
        log_action_breakdown(policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.valid_position_loss, self.valid_max_position_loss,
                             self.valid_rotation_loss, self.valid_max_rotation_loss,
                             self.valid_gripper_loss, self.compute_MAE)
    else:
        enc_policy, plan_policy, encoding, plan = self.step(inputs)
        act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                              self.metrics['valid_act_with_enc_loss'])
        if self.args.discrete:
            loss = act_enc_loss
            log_action_breakdown(enc_policy, actions, mask, seq_lens,
                                 self.args.num_distribs is not None, self.dl.quaternion_act,
                                 self.metrics['valid_position_loss'], self.metrics['valid_max_position_loss'],
                                 self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'],
                                 self.metrics['valid_gripper_loss'], self.compute_MAE)
        else:
            act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                                   self.metrics['valid_act_with_plan_loss'])
            reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                              self.metrics['valid_reg_loss'])
            loss = act_plan_loss + reg_loss * beta
            log_action_breakdown(plan_policy, actions, mask, seq_lens,
                                 self.args.num_distribs is not None, self.dl.quaternion_act,
                                 self.metrics['valid_position_loss'], self.metrics['valid_max_position_loss'],
                                 self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'],
                                 self.metrics['valid_gripper_loss'], self.compute_MAE)
    return record(loss, self.metrics['valid_loss'])
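# `record` is used throughout these train/test steps but is not defined in this section.
# A minimal sketch of the assumed behaviour (accumulate the value into a tf.keras metric
# and pass it through unchanged so calls can be chained inline) could look like this:
def record(value, metric):
    # Update the running metric, e.g. a tf.keras.metrics.Mean
    metric(value)
    # Return the value unchanged so it can be used directly, e.g. loss = record(loss, m)
    return value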
def test_step(self, inputs, beta):
    states, actions, goals, seq_lens, mask = inputs['obs'], inputs['acts'], inputs['goals'], inputs['seq_lens'], inputs['masks']
    ########################### Between here
    if self.args.images:
        imgs, proprioceptive_features, goal_imgs = inputs['imgs'], inputs['proprioceptive_features'], inputs['goal_imgs']
        B, T, H, W, C = imgs.shape
        imgs, goal_imgs = tf.reshape(imgs, [B * T, H, W, C]), tf.reshape(goal_imgs, [B * T, H, W, C])
        img_embeddings, goal_embeddings = tf.reshape(self.cnn(imgs), [B, T, -1]), tf.reshape(self.cnn(goal_imgs), [B, T, -1])
        # concatenate the image embedding with the proprioceptive pose (xyz, orientation, gripper angle)
        states = tf.concat([img_embeddings, proprioceptive_features], -1)
        goals = goal_embeddings  # shape should be (B, T, embed_size)
    if self.args.gcbc:
        policy = self.actor([states, goals], training=False)
        loss = self.compute_loss(actions, policy, mask, seq_lens)
        log_action_breakdown(policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.valid_position_loss, self.valid_max_position_loss,
                             self.valid_rotation_loss, self.valid_max_rotation_loss,
                             self.valid_gripper_loss, self.compute_MAE)
    else:
        encoding = self.encoder([states, actions])
        # the final goal is tiled over the entire unmasked sequence, so the first timestep already contains it
        plan = self.planner([states[:, 0, :], goals[:, 0, :]])
        z_enc = encoding.sample()
        z_plan = plan.sample()
        z_enc_tiled = tf.tile(tf.expand_dims(z_enc, 1), (1, self.dl.window_size, 1))
        z_plan_tiled = tf.tile(tf.expand_dims(z_plan, 1), (1, self.dl.window_size, 1))
        enc_policy = self.actor([states, z_enc_tiled, goals])
        plan_policy = self.actor([states, z_plan_tiled, goals])
        ############### and here could be abstracted into one function
        act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                              self.metrics['valid_act_with_enc_loss'])
        act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                               self.metrics['valid_act_with_plan_loss'])
        reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                          self.metrics['valid_reg_loss'])
        loss = act_plan_loss + reg_loss * beta
        log_action_breakdown(plan_policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.metrics['valid_position_loss'], self.metrics['valid_max_position_loss'],
                             self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'],
                             self.metrics['valid_gripper_loss'], self.compute_MAE)
    if self.args.gcbc:
        return record(loss, self.metrics['valid_loss'])
    else:
        return record(loss, self.metrics['valid_loss']), z_enc, z_plan
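# The markers above note that the block between them "could be abstracted into one function",
# which is what the other variants do by calling self.step(inputs). The real self.step is not
# shown in this section, so the signature and return values below are assumptions; this sketch
# simply lifts the inline image-embedding and encoder/planner/actor logic from the method above:
def step(self, inputs):
    states, actions, goals = inputs['obs'], inputs['acts'], inputs['goals']
    if self.args.images:
        # Embed per-timestep images and goal images with the shared CNN, then fold the
        # proprioceptive features back in, exactly as in the inline version above.
        imgs, prop, goal_imgs = inputs['imgs'], inputs['proprioceptive_features'], inputs['goal_imgs']
        B, T, H, W, C = imgs.shape
        img_embed = tf.reshape(self.cnn(tf.reshape(imgs, [B * T, H, W, C])), [B, T, -1])
        goal_embed = tf.reshape(self.cnn(tf.reshape(goal_imgs, [B * T, H, W, C])), [B, T, -1])
        states = tf.concat([img_embed, prop], -1)
        goals = goal_embed
    if self.args.gcbc:
        return self.actor([states, goals])
    encoding = self.encoder([states, actions])
    plan = self.planner([states[:, 0, :], goals[:, 0, :]])
    z_enc_tiled = tf.tile(tf.expand_dims(encoding.sample(), 1), (1, self.dl.window_size, 1))
    z_plan_tiled = tf.tile(tf.expand_dims(plan.sample(), 1), (1, self.dl.window_size, 1))
    enc_policy = self.actor([states, z_enc_tiled, goals])
    plan_policy = self.actor([states, z_plan_tiled, goals])
    return enc_policy, plan_policy, encoding, plan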
def test_step(self, **kwargs):
    inputs, beta, lang_labelled_inputs, external_videos = kwargs['batch'], kwargs['beta'], kwargs['lang'], kwargs['video']
    inputs = self.make_sequences_variable_length(inputs)
    if self.args.gcbc:
        # only unpack here; in the non-gcbc path self.step returns actions, mask and seq_lens itself
        actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs['masks']
        policy = self.step(inputs)
        loss = self.compute_loss(actions, policy, mask, seq_lens)
        log_action_breakdown(policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.valid_position_loss, self.valid_max_position_loss,
                             self.valid_rotation_loss, self.valid_max_rotation_loss,
                             self.valid_gripper_loss, self.compute_MAE)
    else:
        enc_policy, plan_policy, encoding, plan, indices, actions, mask, seq_lens, sentence_embeddings = self.step(
            inputs, lang_labelled_inputs, external_videos)
        act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                              self.metrics['valid_act_with_enc_loss'])
        if self.args.discrete:
            planner_loss = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.stop_gradient(tf.nn.softmax(encoding, -1)), logits=plan)
            record(planner_loss, self.metrics['valid_discrete_planner_loss'])
            loss = act_enc_loss + planner_loss * beta
        else:
            act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                                   self.metrics['valid_act_with_plan_loss'])
            reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                              self.metrics['valid_reg_loss'])
            loss = act_plan_loss + reg_loss * beta
        log_action_breakdown(plan_policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.metrics['valid_position_loss'], self.metrics['valid_max_position_loss'],
                             self.metrics['valid_rotation_loss'], self.metrics['valid_max_rotation_loss'],
                             self.metrics['valid_gripper_loss'], self.compute_MAE)
        log_action_breakdown(enc_policy, actions, mask, seq_lens,
                             self.args.num_distribs is not None, self.dl.quaternion_act,
                             self.metrics['valid_enc_position_loss'], self.metrics['valid_enc_max_position_loss'],
                             self.metrics['valid_enc_rotation_loss'], self.metrics['valid_enc_max_rotation_loss'],
                             self.metrics['valid_enc_gripper_loss'], self.compute_MAE)
        if self.args.use_language:
            # pass probabilistic=False and slice the distribution's .sample(), since slicing
            # the distribution object directly appears to sample it automatically
            log_action_breakdown(plan_policy.sample()[indices['unlabelled']:],
                                 actions[indices['unlabelled']:],
                                 mask[indices['unlabelled']:],
                                 seq_lens[indices['unlabelled']:],
                                 False, self.dl.quaternion_act,
                                 self.metrics['valid_lang_position_loss'], self.metrics['valid_lang_max_position_loss'],
                                 self.metrics['valid_lang_rotation_loss'], self.metrics['valid_lang_max_rotation_loss'],
                                 self.metrics['valid_lang_gripper_loss'], self.compute_MAE)
    return record(loss, self.metrics['valid_loss'])
def train_step(self, inputs, beta):
    inputs = self.make_sequences_variable_length(inputs)
    with tf.GradientTape() as actor_tape, tf.GradientTape() as encoder_tape, tf.GradientTape() as planner_tape:
        actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs['masks']
        if self.args.gcbc:
            policy = self.step(inputs)
            loss = self.compute_loss(actions, policy, mask, seq_lens)
            gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
            self.actor_optimizer.apply_gradients(zip(gradients, self.actor.trainable_variables))
        else:
            enc_policy, plan_policy, encoding, plan = self.step(inputs)
            act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                                  self.metrics['train_act_with_enc_loss'])
            act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                                   self.metrics['train_act_with_plan_loss'])
            reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                              self.metrics['train_reg_loss'])
            loss = act_enc_loss + reg_loss * beta

            if self.args.fp16:
                actor_gradients = self.compute_fp16_grads(self.actor_optimizer, loss, actor_tape, self.actor)
                encoder_gradients = self.compute_fp16_grads(self.encoder_optimizer, loss, encoder_tape, self.encoder)
                planner_gradients = self.compute_fp16_grads(self.planner_optimizer, loss, planner_tape, self.planner)
            else:
                actor_gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
                encoder_gradients = encoder_tape.gradient(loss, self.encoder.trainable_variables)
                planner_gradients = planner_tape.gradient(loss, self.planner.trainable_variables)

            actor_norm = record(tf.linalg.global_norm(actor_gradients), self.metrics['actor_grad_norm'])
            encoder_norm = record(tf.linalg.global_norm(encoder_gradients), self.metrics['encoder_grad_norm'])
            planner_norm = record(tf.linalg.global_norm(planner_gradients), self.metrics['planner_grad_norm'])
            gradients = actor_gradients + encoder_gradients + planner_gradients
            record(tf.linalg.global_norm(gradients), self.metrics['global_grad_norm'])

            self.actor_optimizer.apply_gradients(zip(actor_gradients, self.actor.trainable_variables))
            self.encoder_optimizer.apply_gradients(zip(encoder_gradients, self.encoder.trainable_variables))
            self.planner_optimizer.apply_gradients(zip(planner_gradients, self.planner.trainable_variables))

    return record(loss, self.metrics['train_loss'])
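# `compute_regularisation_loss(plan, encoding)` is called above but not defined in this section.
# In LMP-style plan-conditioned models this regulariser is typically a KL term between the
# encoder's posterior and the planner's prior over the latent plan. A minimal sketch under that
# assumption (tfp distributions as inputs, KL(posterior || prior), batch-averaged; the actual
# direction and reduction used here are not shown, so treat this only as an illustration):
import tensorflow_probability as tfp

def compute_regularisation_loss(self, plan, encoding):
    # KL divergence between the encoder (posterior) and planner (prior) distributions
    kl = tfp.distributions.kl_divergence(encoding, plan)
    return tf.reduce_mean(kl)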
def train_step(self, inputs, beta):
    with tf.GradientTape() as actor_tape, tf.GradientTape() as encoder_tape, tf.GradientTape() as planner_tape:
        # Todo: figure out mask and seq_lens for the new dataset
        states, actions, goals, seq_lens, mask = inputs['obs'], inputs['acts'], inputs['goals'], inputs['seq_lens'], inputs['masks']
        # Steps needed for image observations:
        # 1. When using images, change the definition of obs_dim to the feature-encoder dim plus the proprioceptive features.
        # 2. Reshape imgs to (B*T, H, W, C).
        # 3. Substitute the embeddings in for states and goals.
        # 4. Then there should be no further changes.
        if self.args.images:
            imgs, proprioceptive_features, goal_imgs = inputs['imgs'], inputs['proprioceptive_features'], inputs['goal_imgs']
            B, T, H, W, C = imgs.shape
            imgs, goal_imgs = tf.reshape(imgs, [B * T, H, W, C]), tf.reshape(goal_imgs, [B * T, H, W, C])
            img_embeddings, goal_embeddings = tf.reshape(self.cnn(imgs), [B, T, -1]), tf.reshape(self.cnn(goal_imgs), [B, T, -1])
            # concatenate the image embedding with the proprioceptive pose (xyz, orientation, gripper angle)
            states = tf.concat([img_embeddings, proprioceptive_features], -1)
            goals = goal_embeddings  # shape should be (B, T, embed_size)
        if self.args.gcbc:
            distrib = self.actor([states, goals])
            loss = self.compute_loss(actions, distrib, mask, seq_lens)
            gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
            self.actor_optimizer.apply_gradients(zip(gradients, self.actor.trainable_variables))
        else:
            encoding = self.encoder([states, actions])
            # the final goal is tiled over the entire unmasked sequence, so the first timestep already contains it
            plan = self.planner([states[:, 0, :], goals[:, 0, :]])
            z_enc = encoding.sample()
            z_plan = plan.sample()
            z_enc_tiled = tf.tile(tf.expand_dims(z_enc, 1), (1, self.dl.window_size, 1))
            z_plan_tiled = tf.tile(tf.expand_dims(z_plan, 1), (1, self.dl.window_size, 1))
            enc_policy = self.actor([states, z_enc_tiled, goals])
            plan_policy = self.actor([states, z_plan_tiled, goals])
            act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                                  self.metrics['train_act_with_enc_loss'])
            act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                                   self.metrics['train_act_with_plan_loss'])
            reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                              self.metrics['train_reg_loss'])
            loss = act_enc_loss + reg_loss * beta

            actor_gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
            encoder_gradients = encoder_tape.gradient(loss, self.encoder.trainable_variables)
            planner_gradients = planner_tape.gradient(loss, self.planner.trainable_variables)
            actor_norm = record(tf.linalg.global_norm(actor_gradients), self.metrics['actor_grad_norm'])
            encoder_norm = record(tf.linalg.global_norm(encoder_gradients), self.metrics['encoder_grad_norm'])
            planner_norm = record(tf.linalg.global_norm(planner_gradients), self.metrics['planner_grad_norm'])
            gradients = actor_gradients + encoder_gradients + planner_gradients
            record(tf.linalg.global_norm(gradients), self.metrics['global_grad_norm'])

            self.actor_optimizer.apply_gradients(zip(actor_gradients, self.actor.trainable_variables))
            self.encoder_optimizer.apply_gradients(zip(encoder_gradients, self.encoder.trainable_variables))
            self.planner_optimizer.apply_gradients(zip(planner_gradients, self.planner.trainable_variables))

    return record(loss, self.metrics['train_loss'])
def train_step(self, **kwargs):
    inputs, beta, lang_labelled_inputs, external_videos, bulk = kwargs['batch'], kwargs['beta'], kwargs['lang'], kwargs['video'], kwargs['bulk']
    if self.args.bulk_split > 0:
        inputs = {k: tf.concat([inputs[k], bulk[k]], axis=0) for k in inputs.keys()}  # combine them
    inputs = self.make_sequences_variable_length(inputs)

    with tf.GradientTape() as actor_tape, tf.GradientTape() as encoder_tape, tf.GradientTape() as planner_tape, \
            tf.GradientTape() as cnn_tape, tf.GradientTape() as gripper_cnn_tape, \
            tf.GradientTape() as img_goal_embed_tape, tf.GradientTape() as lang_goal_embed_tape:
        if self.args.gcbc:
            # only unpack here; in the non-gcbc path self.step returns actions, mask and seq_lens itself
            actions, seq_lens, mask = inputs['acts'], inputs['seq_lens'], inputs['masks']
            policy = self.step(inputs)
            loss = self.compute_loss(actions, policy, mask, seq_lens)
            gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
            self.actor_optimizer.apply_gradients(zip(gradients, self.actor.trainable_variables))
        else:
            enc_policy, plan_policy, encoding, plan, indices, actions, mask, seq_lens, sentence_embeddings = self.step(
                inputs, lang_labelled_inputs, external_videos)
            act_enc_loss = record(self.compute_loss(actions, enc_policy, mask, seq_lens),
                                  self.metrics['train_act_with_enc_loss'])
            if self.args.discrete:
                planner_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.stop_gradient(tf.nn.softmax(encoding, -1)), logits=plan)
                record(planner_loss, self.metrics['train_discrete_planner_loss'])
                loss = act_enc_loss + planner_loss * beta
            else:
                act_plan_loss = record(self.compute_loss(actions, plan_policy, mask, seq_lens),
                                       self.metrics['train_act_with_plan_loss'])
                reg_loss = record(self.compute_regularisation_loss(plan, encoding),
                                  self.metrics['train_reg_loss'])
                loss = act_enc_loss + reg_loss * beta

            if self.args.fp16:
                actor_gradients = self.compute_fp16_grads(self.actor_optimizer, loss, actor_tape, self.actor)
                encoder_gradients = self.compute_fp16_grads(self.encoder_optimizer, loss, encoder_tape, self.encoder)
                planner_gradients = self.compute_fp16_grads(self.planner_optimizer, loss, planner_tape, self.planner)
                if self.args.images:
                    cnn_gradients = self.compute_fp16_grads(self.cnn_optimizer, loss, cnn_tape, self.cnn)
                    img_goal_to_goal_space_grads = self.compute_fp16_grads(
                        self.img_embed_to_goal_space_optimizer, loss, img_goal_embed_tape, self.img_embed_to_goal_space)
                if self.args.gripper_images:
                    gripper_cnn_gradients = self.compute_fp16_grads(
                        self.gripper_cnn_optimizer, loss, gripper_cnn_tape, self.gripper_cnn)
                if self.args.use_language:
                    raise NotImplementedError
            else:
                actor_gradients = actor_tape.gradient(loss, self.actor.trainable_variables)
                encoder_gradients = encoder_tape.gradient(loss, self.encoder.trainable_variables)
                planner_gradients = planner_tape.gradient(loss, self.planner.trainable_variables)
                if self.args.images:
                    cnn_gradients = cnn_tape.gradient(loss, self.cnn.trainable_variables)
                    img_goal_to_goal_space_grads = img_goal_embed_tape.gradient(
                        loss, self.img_embed_to_goal_space.trainable_variables)
                if self.args.gripper_images:
                    gripper_cnn_gradients = gripper_cnn_tape.gradient(loss, self.gripper_cnn.trainable_variables)
                if self.args.use_language:
                    lang_goal_to_goal_space_grads = lang_goal_embed_tape.gradient(
                        loss, self.lang_embed_to_goal_space.trainable_variables)

            #################### Calc individual norms
            actor_norm = record(tf.linalg.global_norm(actor_gradients), self.metrics['actor_grad_norm'])
            encoder_norm = record(tf.linalg.global_norm(encoder_gradients), self.metrics['encoder_grad_norm'])
            planner_norm = record(tf.linalg.global_norm(planner_gradients), self.metrics['planner_grad_norm'])
            if self.args.images:
                cnn_norm = record(tf.linalg.global_norm(cnn_gradients), self.metrics['cnn_grad_norm'])
                img_goal_to_goal_space_norm = record(tf.linalg.global_norm(img_goal_to_goal_space_grads),
                                                     self.metrics['img_embed_to_goal_space_norm'])
            if self.args.gripper_images:
                gripper_cnn_norm = record(tf.linalg.global_norm(gripper_cnn_gradients),
                                          self.metrics['gripper_cnn_grad_norm'])
            if self.args.use_language:
                lang_goal_to_goal_space_norm = record(tf.linalg.global_norm(lang_goal_to_goal_space_grads),
                                                      self.metrics['lang_embed_to_goal_space_norm'])

            ##################### Calc global grad norm
            gradients = actor_gradients + encoder_gradients + planner_gradients
            if self.args.images:
                gradients = gradients + cnn_gradients + img_goal_to_goal_space_grads
            if self.args.gripper_images:
                gradients += gripper_cnn_gradients
            if self.args.use_language:
                gradients += lang_goal_to_goal_space_grads
            record(tf.linalg.global_norm(gradients), self.metrics['global_grad_norm'])

            #################### Apply optimizer updates
            self.actor_optimizer.apply_gradients(zip(actor_gradients, self.actor.trainable_variables))
            self.encoder_optimizer.apply_gradients(zip(encoder_gradients, self.encoder.trainable_variables))
            if not self.args.discrete:
                self.planner_optimizer.apply_gradients(
                    zip(planner_gradients, self.planner.trainable_variables))  # TODO: train as second stage
            if self.args.images:
                self.cnn_optimizer.apply_gradients(zip(cnn_gradients, self.cnn.trainable_variables))
                self.img_embed_to_goal_space_optimizer.apply_gradients(
                    zip(img_goal_to_goal_space_grads, self.img_embed_to_goal_space.trainable_variables))
            if self.args.gripper_images:
                self.gripper_cnn_optimizer.apply_gradients(
                    zip(gripper_cnn_gradients, self.gripper_cnn.trainable_variables))
            if self.args.use_language:
                self.lang_embed_to_goal_space_optimizer.apply_gradients(
                    zip(lang_goal_to_goal_space_grads, self.lang_embed_to_goal_space.trainable_variables))

    ################### Fin
    return record(loss, self.metrics['train_loss'])
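# `compute_fp16_grads(optimizer, loss, tape, model)` is used in the fp16 paths above but not
# defined in this section. With tf.keras mixed precision the usual pattern is to wrap each
# optimizer in a LossScaleOptimizer, scale the loss before differentiating, and unscale the
# resulting gradients. A sketch under that assumption (note it relies on being called while
# the tape is still recording, as in the train_step above, so the scaling multiply is taped):
def compute_fp16_grads(self, optimizer, loss, tape, model):
    # `optimizer` is assumed to be a tf.keras.mixed_precision.LossScaleOptimizer
    scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    # Undo the loss scaling so the gradients can be logged and applied as usual
    return optimizer.get_unscaled_gradients(scaled_grads)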