Esempio n. 1
0
                K.tf.Session(
                    config=K.tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False,
                                            gpu_options=gpu_options)))
        if args.tf:
            tf_device = device
            if hide_device:
                tf_device = 'gpu0' if 'gpu' in device else ''
            model_builder = ModelTensorFlow(comm,
                                            source=args.model,
                                            device_name=tf_device,
                                            weights=model_weights)
            logging.debug("Using device {}".format(model_builder.device))
        else:
            model_builder = ModelFromJson(comm,
                                          args.model,
                                          weights=model_weights)
            logging.debug("using device {}".format(device))
            os.environ[
                'THEANO_FLAGS'] = "profile=%s,device=%s,floatX=float32" % (
                    args.profile, device.replace('gpu', 'cuda'))
            # GPU ops need to be executed synchronously in order for profiling to make sense
        if args.profile:
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    data = H5Data(batch_size=args.batch,
                  cache=args.caching_dir,
                  preloading=args.data_preload,
                  features_name=args.features_name,
                  labels_name=args.labels_name)
    # We initialize the Data object with the training data list
Esempio n. 2
0
    def _execute_MPI(self,
                     comm=None,
                     archiveTraining=True,
                     archiveValidation=True,
                     verbose=1):
        """Train this trial's model across MPI processes through mpi_learn.

        The training and validation procedures are resolved to file paths,
        an ``H5Data`` object is used to count the training examples, an
        ``Algo`` and a ``ModelFromJson`` builder are configured from the
        trial's attributes, and an ``MPIManager`` is created on ``comm``.
        On rank 0 the example counts are written to the trial record (when
        missing) and ``manager.process.train()`` is run and timed.

        Args:
            comm: MPI communicator to run on. When ``None`` a duplicate of
                ``MPI.COMM_WORLD`` is used.
            archiveTraining: Not referenced in the visible part of this
                method.
            archiveValidation: Not referenced in the visible part of this
                method.
            verbose: Forwarded to ``self._generateCallbacks``.

        Raises:
            ValueError: If ``train_procedure`` or ``val_procedure`` is not a
                list, or if a custom object cannot be imported from its
                declared module.
            IOError: If a training/validation entry is a path that does not
                point to an existing file.
        """
        from mpi4py import MPI
        from mpi_learn.mpi.manager import MPIManager, get_device
        from mpi_learn.train.algo import Algo
        from mpi_learn.train.data import H5Data
        from mpi_learn.train.model import ModelFromJson

        # NOTE(review): not read anywhere in the visible part of this method;
        # kept in case code following this excerpt relies on it.
        load_weights = True

        if comm is None:
            comm = MPI.COMM_WORLD.Dup()

        if not isinstance(self.train_procedure, list):
            raise ValueError("Trial attribute train_procedure: expected list of DataProcedures or paths but got type %r" % type(self.train_procedure))
        if not isinstance(self.val_procedure, list):
            raise ValueError("Trial attribute val_procedure: expected list of DataProcedures or paths but got type %r" % type(self.val_procedure))

        train = [DataProcedure.from_json(self.archive_dir, x) if isinstance(x, DataProcedure) else str(x)
                 for x in self.train_procedure]
        val = [DataProcedure.from_json(self.archive_dir, x) if isinstance(x, DataProcedure) else str(x)
               for x in self.val_procedure]

        batchAssertArchived(train)
        batchAssertArchived(val)

        def archive_path(item):
            # Map a DataProcedure to its archive file, or validate a plain path.
            if isinstance(item, DataProcedure):
                return item.get_path() + "archive.h5"
            if os.path.isfile(item):
                return item
            raise IOError("Cannot find %r" % item)

        train_list = [archive_path(item) for item in train]
        val_list = [archive_path(item) for item in val]

        # There is an issue when multiple processes import Keras simultaneously --
        # the file .keras/keras.json is sometimes not read correctly.
        # As a workaround, just try several times to import keras.
        # Note: importing keras imports theano --
        # impossible to change GPU choice after this.
        # If every attempt fails, execution simply continues past the loop.
        for try_num in range(10):
            try:
                from keras.models import model_from_json
                import keras.callbacks as cbks
                break
            except ValueError:
                print("Unable to import keras. Trying again: %d" % try_num)
                sleep(0.1)

        # Resolve every custom object name to the attribute of its module.
        custom_objects = {}
        for name, module in self.custom_objects.items():
            try:
                custom_objects[name] = getattr(importlib.import_module(module), name)
            except Exception:
                raise ValueError("Custom Object %r does not exist in %r. \
                    For best results Custom Objects should be importable and not locally defined." % (str(name), str(module)))

        # We initialize the Data object with the training data list
        # so that we can use it to count the number of training examples.
        data = H5Data(batch_size=self.batch_size,
                      features_name=self.features_name,
                      labels_name=self.labels_name)
        data.set_file_names(train_list)
        num_train = data.count_data()

        # Floor division keeps this an integer count of batches on both
        # Python 2 and Python 3.
        validate_every = num_train // self.batch_size

        if self.easgd:
            algo = Algo(None, loss=self.loss, validate_every=validate_every,
                        mode='easgd', elastic_lr=1.0, sync_every=self.sync_every,
                        worker_optimizer='sgd',
                        elastic_force=0.9/(comm.Get_size()-1))
        else:
            algo = Algo(self.master_optimizer, loss=self.loss, validate_every=validate_every,
                        sync_every=self.sync_every, worker_optimizer=self.optimizer)

        model_builder = ModelFromJson(comm, json_str=self.model, custom_objects=custom_objects)

        callbacks = self._generateCallbacks(verbose=verbose)

        # Trials may carry either an `epochs` or an `nb_epoch` attribute.
        num_epochs = self.epochs if hasattr(self, 'epochs') else self.nb_epoch

        # Creating the MPIManager object causes all needed worker and master nodes to be created
        manager = MPIManager(comm=comm, data=data, num_epochs=num_epochs,
                             algo=algo, model_builder=model_builder,
                             train_list=train_list, val_list=val_list, num_masters=self.masters,
                             synchronous=self.synchronous, callbacks=callbacks, custom_objects=custom_objects)

        # Only rank 0 updates the trial record and drives training through the manager.
        if comm.Get_rank() == 0:
            record = self.read_record()
            if "num_train" not in record:
                self.to_record({"num_train": num_train})
            if "num_val" not in record:
                # Built the same way as the training Data object above:
                # construct first, then attach the file list.
                val_data = H5Data(batch_size=self.batch_size,
                                  features_name=self.features_name,
                                  labels_name=self.labels_name)
                val_data.set_file_names(val_list)
                self.to_record({"num_val": val_data.count_data()})

            print(custom_objects)

            print(algo)

            t_0 = time()
            histories = manager.process.train()
            delta_t = time() - t_0
            manager.free_comms()
            print("Training finished in %.3f seconds" % delta_t)
            print(histories)