def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Bind the controller to ``robot`` and set up devices, network and logger.

        Args:
            robot: Webots robot handle used to look up sensors, GPS and motors.
            n_iterations: number of iterations passed to ``ExperimentLogger``.
            parameters: simulation parameters; given to ``SalamanderNetwork``
                and unpacked as keyword arguments into ``ExperimentLogger``.
            logs: file name handed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        step_ms = int(robot.getBasicTimeStep())
        # The network and logger receive the basic time step scaled by 1e-3.
        step_s = 1e-3*step_ms
        self.network = SalamanderNetwork(step_s, parameters)

        # Position sensors: body joints first, then the legs.
        sensor_names = ['position_sensor_{}'.format(n)
                        for n in range(1, self.N_BODY_JOINTS + 1)]
        sensor_names += ['position_sensor_leg_{}'.format(n)
                         for n in range(1, self.N_LEGS + 1)]
        self.position_sensors = [self.robot.getPositionSensor(name)
                                 for name in sensor_names]
        for sensor in self.position_sensors:
            sensor.enable(step_ms)

        # GPS on the front girdle.
        # NOTE(review): ``self.enable`` is not read anywhere in the visible code.
        self.enable = False
        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(step_ms)

        # Motors: body joints and legs.
        self.motors_body = [self.robot.getMotor("motor_{}".format(n))
                            for n in range(1, self.N_BODY_JOINTS + 1)]
        self.motors_legs = [self.robot.getMotor("motor_leg_{}".format(n))
                            for n in range(1, self.N_LEGS + 1)]

        # Initial targets: body straight (0), legs at -pi/2; enable feedback.
        for group, rest_angle in ((self.motors_body, 0),
                                  (self.motors_legs, -np.pi/2)):
            for motor in group:
                motor.setPosition(rest_angle)
                motor.enableForceFeedback(step_ms)
                motor.enableTorqueFeedback(step_ms)

        self.iteration = 0

        self.log = ExperimentLogger(
            n_iterations,
            n_links=1,
            n_joints=self.N_BODY_JOINTS + self.N_LEGS,
            filename=logs,
            timestep=step_s,
            **parameters
        )
Ejemplo n.º 2
0
def run(options):
    """Build data, model and evaluators, then either evaluate once or train.

    Args:
        options: parsed experiment options. The attributes read here are
            ``eval_config``, ``eval_only_mode``, ``experiment_path`` and
            ``save_init``; the rest is forwarded to the helpers.

    In ``eval_only_mode`` a single evaluation at step 0 is run and the
    process exits via ``sys.exit()``. Otherwise an optional initial
    checkpoint is saved and control passes to ``run_train``.
    """
    logger = get_logger()
    # The instance is unused here. The call is kept in case construction has
    # side effects on shared logger state (TODO confirm).
    ExperimentLogger()

    train_dataset, validation_dataset = get_train_and_validation(options)
    train_iterator = get_train_iterator(options, train_dataset)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    embeddings = train_dataset['embeddings']
    word2idx = train_dataset['word2idx']

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, train_iterator, word2idx=word2idx)
    logger.info('Model:')
    for name, p in trainer.net.named_parameters():
        logger.info('{} {} {}'.format(name, p.shape, p.requires_grad))

    # Evaluation components run on the validation split.
    context = {
        'dataset': validation_dataset,
        'batch_iterator': validation_iterator,
    }
    model_evaluation = ModelEvaluation(
        get_eval_components(options, context, config_lst=options.eval_config))

    if options.eval_only_mode:
        info = {'experiment_path': options.experiment_path, 'step': 0}
        # run_train iterates run_evaluation's result, so it may be a lazy
        # generator. Consume it so the evaluation actually executes before exit.
        for _ in run_evaluation(options, trainer, model_evaluation, info,
                                metadata=dict(step=0)):
            pass
        sys.exit()

    if options.save_init:
        logger.info('Saving model (init).')
        trainer.save_model(os.path.join(options.experiment_path, 'model_init.pt'))

    run_train(options, train_iterator, trainer, model_evaluation)
Ejemplo n.º 3
0
def run_train(options, train_iterator, trainer, model_evaluation):
    """Train ``trainer`` with periodic checkpoints and best-metric tracking.

    For every batch yielded by ``MyTrainIterator`` the trainer takes one step
    and its result is recorded. The per-batch ``should`` flags from the
    iterator decide when to:

    * log the batch (``'log'``),
    * save ``model_periodic.pt`` / ``experiment_periodic.json`` (``'periodic'``),
    * save step-stamped checkpoints (``'distinct'``),
    * evaluate (``'eval'``),
    * run end-of-epoch hooks (``'end_of_epoch'``).

    During evaluation, every (eval, metric) pair that ``component.compare``
    reports as a new best is stored in ``best_dict``. Once
    ``step >= options.save_after``, that pair is also saved as its own model
    and experiment checkpoint.

    NOTE(review): ``best_step`` and ``best_metric`` are written into the
    periodic/distinct experiment files but never updated, so they always hold
    their initial values (0 and ``inf``). The real per-metric bests live in
    ``best_dict``. Confirm whether they should track one of those entries.
    """
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    step = 0
    best_step = 0
    best_metric = math.inf
    best_dict = {}

    my_iterator = MyTrainIterator(options, train_iterator)

    for epoch, step, batch_idx, batch_map, should in my_iterator.get_iterator():
        # HACK: Weird fix that counteracts other libraries (i.e. allennlp) modifying
        # the global logger. Needed after adding tensorboard.
        # The root logger has no parent, so guard against None.
        if logger.parent is not None:
            logger.parent.handlers.clear()

        batch_map['step'] = step
        result = trainer.step(batch_map)
        experiment_logger.record(result)
        del result

        if should['log']:
            experiment_logger.log_batch(epoch, step, batch_idx)

        # -- Periodic Checkpoints -- #

        if should['periodic']:
            logger.info('Saving model (periodic).')
            trainer.save_model(os.path.join(options.experiment_path, 'model_periodic.pt'))
            save_experiment(os.path.join(options.experiment_path, 'experiment_periodic.json'),
                            dict(step=step, epoch=epoch, best_step=best_step, best_metric=best_metric))

        if should['distinct']:
            logger.info('Saving model (distinct).')
            trainer.save_model(os.path.join(options.experiment_path, 'model.step_{}.pt'.format(step)))
            save_experiment(os.path.join(options.experiment_path, 'experiment.step_{}.json'.format(step)),
                            dict(step=step, epoch=epoch, best_step=best_step, best_metric=best_metric))

        # -- Validation -- #

        if should['eval']:
            logger.info('Evaluation.')

            info = {
                'experiment_path': options.experiment_path,
                'step': step,
                'epoch': epoch,
            }

            for eval_result_dict in run_evaluation(options, trainer, model_evaluation, info, metadata=dict(step=step)):
                result = eval_result_dict['result']
                func = eval_result_dict['component']
                name = func.name

                for key, val, is_best in func.compare(best_dict, result):
                    if not is_best:
                        continue

                    best_dict_key = 'best__{}__{}'.format(name, key)
                    prev_val = best_dict.get(best_dict_key, {}).get('value')

                    # Update running result (key order matters for the JSON dumps below).
                    best_dict[best_dict_key] = {
                        'eval': name,
                        'metric': key,
                        'value': val,
                        'step': step,
                        'epoch': epoch,
                    }

                    logger.info('Recording, best eval, key = {}, val = {} -> {}, json = {}'.format(
                        best_dict_key, prev_val, val, json.dumps(best_dict[best_dict_key])))

                    # Only persist bests once training is past options.save_after.
                    if step >= options.save_after:
                        logger.info('Saving model, best eval, key = {}, json = {}'.format(
                            best_dict_key, json.dumps(best_dict[best_dict_key])))
                        logger.info('checkpoint_dir = {}'.format(options.experiment_path))

                        # Save result and model.
                        trainer.save_model(
                            os.path.join(options.experiment_path, 'model.{}.pt'.format(best_dict_key)))
                        save_experiment(os.path.join(options.experiment_path, 'experiment.{}.json'.format(best_dict_key)),
                                        best_dict[best_dict_key])

        # -- End of epoch -- #

        if should['end_of_epoch']:
            experiment_logger.log_epoch(epoch, step)
            trainer.end_of_epoch(best_dict)
Ejemplo n.º 4
0
class SalamanderCMC(object):
    """Salamander robot for CMC, with keyboard steering and a GPS-driven drive ramp.

    Each ``step`` advances the ``SalamanderNetwork``, sends the resulting joint
    targets to the motors, reacts to a keyboard command, optionally ramps the
    drive from the front-girdle GPS x coordinate, and logs the robot state.

    Keyboard commands (a key is acted on only when it differs from the last
    accepted command key):

    * ``A`` / ``D``: turn left / right (asymmetric ``turnRate``)
    * ``W`` / ``S``: go forward / backward (``Backwards`` flag, straight)
    * ``T``: toggle the GPS-based drive transition

    NOTE(review): ``lastkey`` is only replaced by another command key. Pressing
    the same key twice in a row (e.g. ``T`` to toggle back) has no effect until
    a different command key is pressed in between.
    """

    N_BODY_JOINTS = 10
    N_LEGS = 4

    def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Bind to ``robot``, configure devices/network/logger and the keyboard.

        Args:
            robot: Webots robot handle.
            n_iterations: number of iterations passed to ``ExperimentLogger``.
            parameters: simulation parameters. They are given to the network,
                unpacked into ``ExperimentLogger``, and mutated by keyboard
                commands and the transition logic.
            logs: file name handed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        timestep = int(robot.getBasicTimeStep())
        self.network = SalamanderNetwork(1e-3*timestep, parameters)

        # Position sensors: body joints first, then legs.
        self.position_sensors = [
            self.robot.getPositionSensor('position_sensor_{}'.format(i+1))
            for i in range(self.N_BODY_JOINTS)
        ] + [
            self.robot.getPositionSensor('position_sensor_leg_{}'.format(i+1))
            for i in range(self.N_LEGS)
        ]
        for sensor in self.position_sensors:
            sensor.enable(timestep)

        # GPS
        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(timestep)

        # Get motors
        self.motors_body = [
            self.robot.getMotor("motor_{}".format(i+1))
            for i in range(self.N_BODY_JOINTS)
        ]
        self.motors_legs = [
            self.robot.getMotor("motor_leg_{}".format(i+1))
            for i in range(self.N_LEGS)
        ]

        # Set motors: body straight, legs at -pi/2; enable feedback.
        for motor in self.motors_body:
            motor.setPosition(0)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)
        for motor in self.motors_legs:
            motor.setPosition(-np.pi/2)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)

        # Iteration counter
        self.iteration = 0

        # Logging
        self.log = ExperimentLogger(
            n_iterations,
            n_links=1,
            n_joints=self.N_BODY_JOINTS+self.N_LEGS,
            filename=logs,
            timestep=1e-3*timestep,
            **parameters
        )

        # GPS x coordinate of the water boundary used by the drive transition.
        self.waterPosx = 0
        self.NetworkParameters = self.network.parameters
        self.SimulationParameters = parameters
        self.doTransition = False

        self.keyboard = Keyboard()
        self.keyboard.enable(samplingPeriod=100)
        self.lastkey = 0

    def log_iteration(self):
        """Log the GPS position and, per joint, position/command/torque."""
        self.log.log_link_positions(self.iteration, 0, self.gps.getValues())
        # Joint index j: body joints first (0..N_BODY_JOINTS-1), then legs,
        # matching the order of self.position_sensors.
        for j, motor in enumerate(self.motors_body + self.motors_legs):
            self.log.log_joint_position(
                self.iteration, j, self.position_sensors[j].getValue())
            self.log.log_joint_cmd(
                self.iteration, j, motor.getTargetPosition())
            # NOTE(review): both "torque" and "torque feedback" log
            # getTorqueFeedback(); confirm whether one should differ.
            self.log.log_joint_torque(
                self.iteration, j, motor.getTorqueFeedback())
            self.log.log_joint_torque_feedback(
                self.iteration, j, motor.getTorqueFeedback())

    def _set_turn_rate(self, turn_rate):
        """Store ``turn_rate`` and recompute the nominal amplitudes."""
        self.SimulationParameters.turnRate = turn_rate
        self.NetworkParameters.set_nominal_amplitudes(self.SimulationParameters)

    def _set_backwards(self, backwards):
        """Store the ``Backwards`` flag and recompute the phase biases."""
        self.SimulationParameters.Backwards = backwards
        self.NetworkParameters.set_phase_bias(self.SimulationParameters)

    def _handle_keyboard(self):
        """Read one key and apply the matching command if it is a new key."""
        key = self.keyboard.getKey()
        if key == self.lastkey:
            return
        if key == ord('A'):
            print('Turning left')
            self._set_turn_rate([0.5, 1])
        elif key == ord('D'):
            print('Turning right')
            self._set_turn_rate([1, 0.5])
        elif key == ord('W'):
            print('Going forward')
            self._set_turn_rate([1, 1])
            self._set_backwards(False)
        elif key == ord('S'):
            print('Going backward')
            self._set_backwards(True)
            self._set_turn_rate([1, 1])
        elif key == ord('T'):
            self.doTransition = not self.doTransition
            print('Enabling transition' if self.doTransition
                  else 'Disabling transition')
        else:
            # Not a command key (including "no key"): keep the last command.
            return
        self.lastkey = key

    def _apply_transition(self):
        """Ramp the drive linearly with GPS x across the water boundary.

        Inside the window (waterPosx - 0.5, waterPosx + 2) the drive goes
        linearly from 1 (at the lower edge) to 5 (at the upper edge).
        Amplitudes and frequencies are then recomputed.
        """
        x = self.gps.getValues()[0]
        if self.waterPosx - 0.5 < x < self.waterPosx + 2:
            # Offset by waterPosx so the ramp follows the same window as the
            # check above (previously hard-coded for waterPosx == 0).
            self.SimulationParameters.drive = 4/2.5*(x - self.waterPosx + 0.5) + 1
            self.NetworkParameters.set_nominal_amplitudes(self.SimulationParameters)
            self.NetworkParameters.set_frequencies(self.SimulationParameters)

    def step(self):
        """Advance one control step: network, motors, keyboard, transition, log."""
        self.iteration += 1

        # Update network
        self.network.step()
        positions = self.network.get_motor_position_output()

        # Update control (legs are offset by -pi/2, matching their rest pose).
        for i in range(self.N_BODY_JOINTS):
            self.motors_body[i].setPosition(positions[i])
        for i in range(self.N_LEGS):
            self.motors_legs[i].setPosition(
                positions[self.N_BODY_JOINTS+i] - np.pi/2
            )

        self._handle_keyboard()

        if self.doTransition:
            self._apply_transition()

        # Log data
        self.log_iteration()
Ejemplo n.º 5
0
def run_experiment():
    """Run all configured training runs, fetching data from and reporting to a server.

    For each run index in ``[resume_run_at or 0, runs)`` this:

    1. seeds locally with the server-provided seed,
    2. recreates the model,
    3. requests train/test data,
    4. trains for ``epochs`` epochs, evaluating after each epoch,
    5. sends the metrics of the last epoch's predictions back to the server.

    Raises:
        ValueError: if CUDA is requested but not available, or if
            ``epochs`` < 1 (no test predictions would exist to score).
    """
    args = parse_args()
    print(args)

    runtime_params_keys = [
        'epochs', 'log_interval', 'log_dir', 'save_model', 'runs'
    ]
    runtime_params = {
        k: v
        for k, v in args.items() if k in runtime_params_keys
    }
    print('Runtime params', runtime_params)

    if args['use_cuda']:
        if not check_cuda_availability(args['model_library']):
            raise ValueError("CUDA was requested but CUDA is not available")

    # Metrics are computed from the last epoch's test predictions. Without at
    # least one epoch they would be unbound (UnboundLocalError), so fail early.
    if args['epochs'] < 1:
        raise ValueError(
            "'epochs' must be at least 1, got {}".format(args['epochs']))

    run_identifier = EvaluationRunIdentifier(
        name=args['bug_name'],
        evaluation_type=args['evaluation_type'],
        challenge=args['challenge'],
        lib_name=args['model_library'],
        model_name=args['model_name'])
    run_identifier_name = EvaluationRunIdentifier.run_identifier(
        run_identifier)
    logger = ExperimentLogger(run_identifier_name, **args)

    # Get server connection. ``context`` is kept referenced for the duration of
    # the runs, presumably because it owns the socket (TODO confirm).
    server_socket, context = connect_server(args['data_server_endpoint'])

    # Request server for seed
    seed = server_interactions.request_seed(server_socket, run_identifier)
    logger.status('Using seed value {}'.format(seed))

    for run in range(args['resume_run_at'] or 0, args['runs']):
        current_seed = seed[run]
        logger.current_run = run
        # Local seed is indexed at the run
        set_local_seed(current_seed)
        model_creation_args = {
            'use_gpu': args['use_cuda'],
            'num_classes': args['num_classes']
        }
        # Recreate the net for each run with new initial weights
        model = ModelStore.get_model_for_name(library=args['model_library'],
                                              name=args['model_name'],
                                              **model_creation_args)
        model.initialize_weights(current_seed)

        data_params = model.get_data_params()
        logger.status('Requesting data from server')
        train_data, test_data = server_interactions.prepare_data_for_run(
            server_socket, run_identifier, run, current_seed, data_params)
        log_data_received(logger=logger,
                          seed=current_seed,
                          train_data=train_data,
                          test_data=test_data)
        logger.status('Received data from server')

        # Parameters are logged both before training and after the run.
        log_params(model, logger)

        model.start_training()
        for epoch in range(1, args['epochs'] + 1):
            train(model, train_data, epoch, logger, **runtime_params)
            np_pred, np_target = test(model, test_data, logger)

        # TODO (opt): Put an option if we want per epoch or per run stats
        metrics = create_metrics_dto(predictions=np_pred, target=np_target)
        logger.metrics(metrics_dto_str(metrics))
        logger.status('Sending metrics to server')
        # NOTE(review): the full ``seed`` value (not ``current_seed``) is sent
        # here; confirm this matches the server API.
        server_interactions.send_metrics_for_run(server_socket, run_identifier,
                                                 seed, run, metrics)

        if args['save_model']:
            model.save(evaluation_type=args['evaluation_type'], run=run)
        log_params(model, logger)
Ejemplo n.º 6
0
class SalamanderCMC(object):
    """Salamander robot for CMC.

    Drives the body and leg motors from a ``SalamanderNetwork`` and records
    the front-girdle GPS position and per-joint state with ``ExperimentLogger``.
    """

    N_BODY_JOINTS = 10
    N_LEGS = 4

    def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Bind to ``robot`` and set up sensors, GPS, motors, network and logger.

        Args:
            robot: Webots robot handle.
            n_iterations: number of iterations passed to ``ExperimentLogger``.
            parameters: simulation parameters; given to ``SalamanderNetwork``
                and unpacked as keyword arguments into ``ExperimentLogger``.
            logs: file name handed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        step_ms = int(robot.getBasicTimeStep())
        step_s = 1e-3 * step_ms
        self.network = SalamanderNetwork(step_s, parameters)

        # Sensors: all body joints, then all legs.
        labels = ['position_sensor_{}'.format(n)
                  for n in range(1, self.N_BODY_JOINTS + 1)]
        labels += ['position_sensor_leg_{}'.format(n)
                   for n in range(1, self.N_LEGS + 1)]
        self.position_sensors = [self.robot.getPositionSensor(label)
                                 for label in labels]
        for sensor in self.position_sensors:
            sensor.enable(step_ms)

        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(step_ms)

        self.motors_body = [self.robot.getMotor("motor_{}".format(n))
                            for n in range(1, self.N_BODY_JOINTS + 1)]
        self.motors_legs = [self.robot.getMotor("motor_leg_{}".format(n))
                            for n in range(1, self.N_LEGS + 1)]

        # Rest pose: body straight, legs at -pi/2; enable force/torque feedback.
        for group, rest_angle in ((self.motors_body, 0),
                                  (self.motors_legs, -np.pi / 2)):
            for motor in group:
                motor.setPosition(rest_angle)
                motor.enableForceFeedback(step_ms)
                motor.enableTorqueFeedback(step_ms)

        self.iteration = 0

        self.log = ExperimentLogger(n_iterations,
                                    n_links=1,
                                    n_joints=self.N_BODY_JOINTS + self.N_LEGS,
                                    filename=logs,
                                    timestep=step_s,
                                    **parameters)

    def log_iteration(self):
        """Record GPS position and each joint's position, command and torque."""
        it = self.iteration
        self.log.log_link_positions(it, 0, self.gps.getValues())
        # Body joints occupy indices 0..9 and legs 10..13, matching the
        # order of self.position_sensors.
        for joint, motor in enumerate(self.motors_body + self.motors_legs):
            self.log.log_joint_position(it, joint,
                                        self.position_sensors[joint].getValue())
            self.log.log_joint_cmd(it, joint, motor.getTargetPosition())
            self.log.log_joint_torque(it, joint, motor.getTorqueFeedback())
            self.log.log_joint_torque_feedback(it, joint,
                                               motor.getTorqueFeedback())

    def step(self):
        """Advance the network once, command the motors and log the state."""
        self.iteration += 1

        self.network.step()
        targets = self.network.get_motor_position_output()

        for idx, motor in enumerate(self.motors_body):
            motor.setPosition(targets[idx])
        # Leg outputs follow the body outputs and are shifted by -pi/2.
        for idx, motor in enumerate(self.motors_legs):
            motor.setPosition(targets[self.N_BODY_JOINTS + idx] - np.pi / 2)

        self.log_iteration()
Ejemplo n.º 7
0
class SalamanderCMC(object):
    """Salamander robot for CMC that switches to a fixed drive past a GPS threshold.

    While the front-girdle GPS x coordinate is above ``TRANSITION_X``, every
    step sets ``drive_mlr`` to ``TRANSITION_DRIVE``. It then derives
    frequencies and nominal amplitudes from ``computedrive.computefreqamp``
    before the network is stepped.
    """

    N_BODY_JOINTS = 10
    N_LEGS = 4
    # GPS x coordinate above which the fixed drive is applied.
    TRANSITION_X = 0.
    # Drive applied once past TRANSITION_X.
    TRANSITION_DRIVE = 3.5

    def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Bind to ``robot`` and set up sensors, GPS, motors, network and logger.

        Args:
            robot: Webots robot handle.
            n_iterations: number of iterations passed to ``ExperimentLogger``.
            parameters: simulation parameters; given to ``SalamanderNetwork``
                and unpacked as keyword arguments into ``ExperimentLogger``.
            logs: file name handed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        timestep = int(robot.getBasicTimeStep())
        self.network = SalamanderNetwork(1e-3 * timestep, parameters)

        # Position sensors: body joints first, then legs.
        self.position_sensors = [
            self.robot.getPositionSensor('position_sensor_{}'.format(i + 1))
            for i in range(self.N_BODY_JOINTS)
        ] + [
            self.robot.getPositionSensor(
                'position_sensor_leg_{}'.format(i + 1))
            for i in range(self.N_LEGS)
        ]
        for sensor in self.position_sensors:
            sensor.enable(timestep)

        # GPS
        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(timestep)

        # Get motors
        self.motors_body = [
            self.robot.getMotor("motor_{}".format(i + 1))
            for i in range(self.N_BODY_JOINTS)
        ]
        self.motors_legs = [
            self.robot.getMotor("motor_leg_{}".format(i + 1))
            for i in range(self.N_LEGS)
        ]

        # Set motors: body straight, legs at -pi/2; enable feedback.
        for motor in self.motors_body:
            motor.setPosition(0)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)
        for motor in self.motors_legs:
            motor.setPosition(-np.pi / 2)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)

        # Iteration counter
        self.iteration = 0

        # Logging
        self.log = ExperimentLogger(n_iterations,
                                    n_links=1,
                                    n_joints=self.N_BODY_JOINTS + self.N_LEGS,
                                    filename=logs,
                                    timestep=1e-3 * timestep,
                                    **parameters)

    def log_iteration(self):
        """Log the GPS position and, per joint, position/command/torque."""
        self.log.log_link_positions(self.iteration, 0, self.gps.getValues())
        # Joint index j: body joints first, then legs (same order as sensors).
        for j, motor in enumerate(self.motors_body + self.motors_legs):
            self.log.log_joint_position(self.iteration, j,
                                        self.position_sensors[j].getValue())
            self.log.log_joint_cmd(self.iteration, j,
                                   motor.getTargetPosition())
            self.log.log_joint_torque(self.iteration, j,
                                      motor.getTorqueFeedback())
            self.log.log_joint_torque_feedback(self.iteration, j,
                                               motor.getTorqueFeedback())

    def _apply_drive(self, drive):
        """Set ``drive_mlr`` and recompute frequencies/amplitudes from it."""
        params = self.network.parameters
        params.drive_mlr = drive
        freqamp = computedrive.computefreqamp(params.drive_mlr, False)
        params.freqs = freqamp[0]
        params.nominal_amplitudes = freqamp[1]
        # NOTE(review): the network parameters are passed as their own
        # "simulation parameters" argument here; confirm this is intended.
        params.set_frequencies(params)
        params.set_nominal_amplitudes(params)

    def step(self):
        """Apply the GPS-based drive switch, step the network, command and log."""
        self.iteration += 1

        # NOTE(review): the drive is recomputed on every step while past the
        # threshold; the values do not change, so this could be done once.
        if self.gps.getValues()[0] > self.TRANSITION_X:
            self._apply_drive(self.TRANSITION_DRIVE)

        # Update network
        self.network.step()
        positions = self.network.get_motor_position_output()

        # Update control (legs are offset by -pi/2, matching their rest pose).
        for i in range(self.N_BODY_JOINTS):
            self.motors_body[i].setPosition(positions[i])
        for i in range(self.N_LEGS):
            self.motors_legs[i].setPosition(positions[self.N_BODY_JOINTS + i] -
                                            np.pi / 2)

        # Log data
        self.log_iteration()
Ejemplo n.º 8
0
class SalamanderCMC(object):
    """Salamander robot for CMC.

    Variant whose ``step`` receives the simulation ``parameters``. Depending on
    ``parameters.flag`` it:

    * "9g": switches between walking and swimming from the GPS x coordinate;
    * "9d1": turns turning on at iteration 2000 and off at iteration 4000.

    Legs whose nominal amplitudes are all ~0 are held still instead of
    following the network output.
    """

    N_BODY_JOINTS = 10
    N_LEGS = 4

    def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Bind to ``robot`` and set up sensors, GPS, motors, network and logger.

        Args:
            robot: Webots robot handle.
            n_iterations: number of iterations passed to ``ExperimentLogger``.
            parameters: simulation parameters; given to ``SalamanderNetwork``
                and unpacked as keyword arguments into ``ExperimentLogger``.
            logs: file name handed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        timestep = int(robot.getBasicTimeStep())
        self.network = SalamanderNetwork(1e-3*timestep, parameters)

        # Position sensors: body joints (indices 0..9), then legs (10..13)
        self.position_sensors = [
            self.robot.getPositionSensor('position_sensor_{}'.format(i+1))
            for i in range(self.N_BODY_JOINTS)
        ] + [
            self.robot.getPositionSensor('position_sensor_leg_{}'.format(i+1))
            for i in range(self.N_LEGS)
        ]
        for sensor in self.position_sensors:
            sensor.enable(timestep)

        # GPS
        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(timestep)

        # Get motors
        self.motors_body = [
            self.robot.getMotor("motor_{}".format(i+1))
            for i in range(self.N_BODY_JOINTS)
        ]
        self.motors_legs = [
            self.robot.getMotor("motor_leg_{}".format(i+1))
            for i in range(self.N_LEGS)
        ]

        # Set motors: body straight, legs at -pi/2; enable feedback
        for motor in self.motors_body:
            motor.setPosition(0)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)
        for motor in self.motors_legs:
            motor.setPosition(-np.pi/2)
            motor.enableForceFeedback(timestep)
            motor.enableTorqueFeedback(timestep)

        # Iteration counter
        self.iteration = 0

        # Logging
        self.log = ExperimentLogger(
            n_iterations,
            n_links=1,
            n_joints=self.N_BODY_JOINTS+self.N_LEGS,
            filename=logs,
            timestep=1e-3*timestep,
            **parameters
        )

    def get_coord(self):
        """Return the current front-girdle GPS reading (``gps.getValues()``)."""
        return self.gps.getValues()

    def log_iteration(self):
        """Log the GPS position and, per joint, position/command/torque."""
        self.log.log_link_positions(self.iteration, 0, self.gps.getValues())
        for i, motor in enumerate(self.motors_body):
            # Position
            self.log.log_joint_position(
                self.iteration, i,
                self.position_sensors[i].getValue()
            )
            # Command
            self.log.log_joint_cmd(
                self.iteration, i,
                motor.getTargetPosition()
            )
            # Torque
            # NOTE(review): this and the "torque feedback" entry below both log
            # getTorqueFeedback(); confirm whether one should differ.
            self.log.log_joint_torque(
                self.iteration, i,
                motor.getTorqueFeedback()
            )
            # Torque feedback
            self.log.log_joint_torque_feedback(
                self.iteration, i,
                motor.getTorqueFeedback()
            )
        # Legs use joint indices 10..13 (after the 10 body joints).
        for i, motor in enumerate(self.motors_legs):
            # Position
            self.log.log_joint_position(
                self.iteration, 10+i,
                self.position_sensors[10+i].getValue()
            )
            # Command
            self.log.log_joint_cmd(
                self.iteration, 10+i,
                motor.getTargetPosition()
            )
            # Torque
            self.log.log_joint_torque(
                self.iteration, 10+i,
                motor.getTorqueFeedback()
            )
            # Torque feedback
            self.log.log_joint_torque_feedback(
                self.iteration, 10+i,
                motor.getTorqueFeedback()
            )

    def step(self,parameters):
        """Advance one control step.

        Args:
            parameters: simulation parameters. ``flag``, ``toggle``, ``drive``
                and ``turning`` are read/written here, then pushed to the
                network via ``self.network.parameters.update(parameters)``.
        """
        # Increment iteration
        self.iteration += 1
        # Update network. The network is stepped before the parameter changes
        # below, so those changes take effect from the next iteration.
        self.network.step()
        positions = self.network.get_motor_position_output()
        # "9g": walking robot passing x > 0.4 switches to swimming with drive 4.
        if parameters.flag == "9g" and self.gps.getValues()[0]>0.4 and parameters.toggle=="walk":
            parameters.toggle="swim"
            parameters.drive=4
            print("d={}".format(parameters.drive))
            self.network.parameters.update(parameters)
        # "9g": swimming robot back below x < 0.2 returns to walking with drive 2
        # and resets the leg oscillator phases.
        # NOTE(review): here ``toggle`` is changed *after* update(), whereas
        # the branch above changes it *before*; if update() reads ``toggle``
        # the two transitions see different states — confirm intent.
        if parameters.flag == "9g" and self.gps.getValues()[0]<0.2 and parameters.toggle=="swim":
            parameters.drive=2
            print("d={}".format(parameters.drive))
            self.network.parameters.update(parameters)
            parameters.toggle="walk"
            self.network.reset_leg_phases()
            
        # "9d1": enable turning at iteration 2000 and disable it at 4000.
        if parameters.flag=="9d1" and self.iteration==2000:
            print("turn on")
            parameters.turning=True
            self.network.parameters.update(parameters)
        if parameters.flag=="9d1" and self.iteration==4000:
            print("turn off")
            parameters.turning=False
            self.network.parameters.update(parameters)
        
        
        
        # Update control
        for i in range(self.N_BODY_JOINTS):
            self.motors_body[i].setPosition(positions[i])
        
        # If every leg nominal amplitude (indices 20:24) is ~0, hold each leg
        # at its measured angle snapped to the nearest multiple of 2*pi, minus
        # the pi/2 offset (presumably to park the legs, e.g. while swimming).
        # Otherwise the legs follow the network output shifted by -pi/2.
        if np.all(abs(self.network.parameters.nominal_amplitudes[20:24])<=0.0001):
            for i in range(self.N_LEGS):
                self.motors_legs[i].setPosition(
                    round(self.position_sensors[i+10].getValue()/(2*math.pi))*2*math.pi-math.pi/2
                )
        else:  
            for i in range(self.N_LEGS):
                self.motors_legs[i].setPosition(
                    positions[self.N_BODY_JOINTS+i] - np.pi/2
                    
                )

        # Log data
        self.log_iteration()
class SalamanderCMC(object):
    """Salamander robot for CMC"""

    N_BODY_JOINTS = 10
    N_LEGS = 4
    X_HIGH_POS = 1.0
    X_LOW_POS = 0.25
    SWIM = True

    def __init__(self, robot, n_iterations, parameters, logs="logs/log.npz"):
        """Build the oscillator network and configure sensors, GPS and motors.

        Parameters
        ----------
        robot :
            Robot handle providing ``getBasicTimeStep``, ``getPositionSensor``,
            ``getGPS`` and ``getMotor`` (Webots-style controller API).
        n_iterations : int
            Number of iterations passed to ``ExperimentLogger``.
        parameters :
            Given to ``SalamanderNetwork`` and also expanded as keyword
            arguments into ``ExperimentLogger`` (so it must be a mapping).
        logs : str
            Filename passed to ``ExperimentLogger``.
        """
        super(SalamanderCMC, self).__init__()
        self.robot = robot
        # Basic time step as reported by the robot (milliseconds in Webots);
        # used as the sampling period of every enabled device.
        step_ms = int(robot.getBasicTimeStep())
        step_s = 1e-3*step_ms  # the same period scaled to seconds
        self.network = SalamanderNetwork(step_s, parameters)

        # Position sensors: body joints first, then legs (device names are 1-based)
        sensor_names = (
            ['position_sensor_{}'.format(k)
             for k in range(1, self.N_BODY_JOINTS + 1)]
            + ['position_sensor_leg_{}'.format(k)
               for k in range(1, self.N_LEGS + 1)]
        )
        self.position_sensors = [
            self.robot.getPositionSensor(name) for name in sensor_names
        ]
        for sensor in self.position_sensors:
            sensor.enable(step_ms)

        # GPS; ``enable`` starts False and is only toggled by step()
        self.enable = False
        self.gps = robot.getGPS("fgirdle_gps")
        self.gps.enable(step_ms)

        # Motors
        self.motors_body = [
            self.robot.getMotor("motor_{}".format(k))
            for k in range(1, self.N_BODY_JOINTS + 1)
        ]
        self.motors_legs = [
            self.robot.getMotor("motor_leg_{}".format(k))
            for k in range(1, self.N_LEGS + 1)
        ]

        # Initial targets (body at 0, legs at -pi/2) and feedback sampling
        initial_targets = (
            [(motor, 0) for motor in self.motors_body]
            + [(motor, -np.pi/2) for motor in self.motors_legs]
        )
        for motor, target in initial_targets:
            motor.setPosition(target)
            motor.enableForceFeedback(step_ms)
            motor.enableTorqueFeedback(step_ms)

        # Number of completed step() calls
        self.iteration = 0

        # Logging: one link (the GPS point) and one entry per joint
        self.log = ExperimentLogger(
            n_iterations,
            n_links=1,
            n_joints=self.N_BODY_JOINTS + self.N_LEGS,
            filename=logs,
            timestep=step_s,
            **parameters
        )

    def log_iteration(self):
        """Log state"""
        self.log.log_link_positions(self.iteration, 0, self.gps.getValues())
        for i, motor in enumerate(self.motors_body):
            # Position
            self.log.log_joint_position(
                self.iteration, i,
                self.position_sensors[i].getValue()
            )
            # Command
            self.log.log_joint_cmd(
                self.iteration, i,
                motor.getTargetPosition()
            )
            # Torque
            self.log.log_joint_torque(
                self.iteration, i,
                motor.getTorqueFeedback()
            )
            # Torque feedback
            self.log.log_joint_torque_feedback(
                self.iteration, i,
                motor.getTorqueFeedback()
            )
        for i, motor in enumerate(self.motors_legs):
            # Position
            self.log.log_joint_position(
                self.iteration, 10+i,
                self.position_sensors[10+i].getValue()
            )
            # Command
            self.log.log_joint_cmd(
                self.iteration, 10+i,
                motor.getTargetPosition()
            )
            # Torque
            self.log.log_joint_torque(
                self.iteration, 10+i,
                motor.getTorqueFeedback()
            )
            # Torque feedback
            self.log.log_joint_torque_feedback(
                self.iteration, 10+i,
                motor.getTorqueFeedback()
            )

    def step(self):
        """Step"""
        # Increment iteration
        self.iteration += 1

        # Update network
        self.network.step()
        positions = self.network.get_motor_position_output()

        # Update control
        for i in range(self.N_BODY_JOINTS):
            self.motors_body[i].setPosition(positions[i])
        for i in range(self.N_LEGS):
            self.motors_legs[i].setPosition(
                positions[self.N_BODY_JOINTS+i] - np.pi/2
            )

        # Log data
        self.log_iteration()

        # Retrieve GPS to change from walking to swimming
        if self.iteration == 1:
            self.enable = True

        pos = self.gps.getValues()

        if self.network.parameters.enable_transitions:
            if pos[0] > self.X_HIGH_POS:
                if not self.SWIM:
                    self.SWIM = True

                self.network.parameters.drive_left = 4.0
                self.network.parameters.drive_right = 4.0
            elif pos[0] < self.X_LOW_POS:
                self.network.parameters.drive_left = 2.0
                self.network.parameters.drive_right = 2.0

                if self.SWIM:
                    self.network.state.phases = 1e-4 * np.random.ranf(self.network.parameters.n_oscillators)
                    self.SWIM = False

            self.network.parameters.set_saturation_params(self.network.parameters)
            self.network.parameters.saturate_params()