Example #1
    def on_train_begin(self, logs=None):
        sess = K.get_session()

        # Split variables based on type -> float32 vs all else
        test_v = tf.Variable([0], dtype=tf.float32)
        all_vars = tf.trainable_variables()
        float_vars = [v for v in all_vars if v.dtype == test_v.dtype]
        other_vars = [v for v in all_vars if v.dtype != test_v.dtype]

        # Initialize variables and broadcast from head node
        sess.run(tf.variables_initializer(all_vars))
        new_vars = mc.broadcast(float_vars, 0)
        bcast = tf.group(
            *[tf.assign(v, new_vars[k]) for k, v in enumerate(float_vars)])
        sess.run(bcast)

        # Validate Broadcast
        if self.validate:
            py_all_vars = [sess.run(v) for v in float_vars]
            var_types = [
                np.array([v]) if isinstance(v, np.float32) else v
                for v in py_all_vars
            ]
            if mc.get_rank() == 0:
                if mc.check_buffers_match(var_types, 1) != 0:
                    tf.logging.error(
                        "Not all processes have the same initial model!")
                else:
                    tf.logging.info("Initial model is consistent on all ranks")
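
Example #1 is the body of a Keras `Callback` hook: on `on_train_begin` it initializes the trainable variables, overwrites them with rank 0's values via `mc.broadcast`, and optionally uses `mc.check_buffers_match` to confirm every rank ended up with the same model. Below is a minimal sketch of how such a method might be packaged and attached to a model; the class name `InitBroadcastCallback` and its constructor are hypothetical, and the `K`, `tf`, `np`, and `mc` aliases used by the method are assumed to be imported at module scope.

from tensorflow import keras  # TF1-era tf.keras; exact import path depends on the installed versions


class InitBroadcastCallback(keras.callbacks.Callback):  # hypothetical wrapper class
    """Broadcast the initial model from rank 0 before training starts."""

    def __init__(self, validate=False):
        super(InitBroadcastCallback, self).__init__()
        self.validate = validate

    # the on_train_begin(self, logs=None) method shown above goes here


# Usage sketch: every rank builds and compiles the same model, then the callback
# makes rank 0's initial weights authoritative on all ranks.
# model.fit(x_train, y_train, callbacks=[InitBroadcastCallback(validate=True)])
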
Example #2
    def begin(self):
        if not self.bcast:
            new_vars = mc.broadcast(tf.trainable_variables(), 0)
            self.bcast = tf.group(*[
                tf.assign(v, new_vars[k])
                for k, v in enumerate(tf.trainable_variables())
            ])
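
Example #2 builds the same broadcast op inside a `tf.train.SessionRunHook`: `begin()` is called once while the graph is being finalized, so the op is constructed a single time and cached on `self.bcast`. The snippet does not show where the op is run; the following is a minimal sketch of a complete hook, assuming (hypothetically) that the cached op is executed in `after_create_session` so every rank starts from rank 0's parameters.

import tensorflow as tf
import ml_comm as mc  # Cray CPE ML Plugin; import alias assumed


class BroadcastVariablesHook(tf.train.SessionRunHook):  # hypothetical class name
    def __init__(self):
        self.bcast = None

    def begin(self):
        # Build the broadcast op once, exactly as in the snippet above.
        if not self.bcast:
            new_vars = mc.broadcast(tf.trainable_variables(), 0)
            self.bcast = tf.group(*[
                tf.assign(v, new_vars[k])
                for k, v in enumerate(tf.trainable_variables())
            ])

    def after_create_session(self, session, coord):
        # Run the cached op so all ranks start from identical parameters.
        session.run(self.bcast)


# Usage sketch with a MonitoredTrainingSession:
# with tf.train.MonitoredTrainingSession(hooks=[BroadcastVariablesHook()]) as sess:
#     ...
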
Example #3
    def train(self):
        train_step, loss, lossL1Train, train_true, train_predict = self.optimize()
        lossL1Val, val_true, val_predict = self.validation_loss()
        lossL1Test, test_true, test_predict = self.test_loss()

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.4

        # Session settings taken from the MKL benchmarks
        config.allow_soft_placement = True
        config.intra_op_parallelism_threads = 1  # default
        config.inter_op_parallelism_threads = 2  # default

        # Used to save the model
        saver = tf.train.Saver()
        global best_validation_accuracy
        global last_improvement
        global total_iterations
        best_validation_accuracy = 1.0  # Best (lowest) validation loss seen so far
        last_improvement = 0            # Epoch number of the last improvement in validation loss
        require_improvement = hp.RUNPARAM['require_improvement']  # Stop optimizing if no improvement is seen for this many epochs
        total_iterations = 0            # Counter for the total number of epochs performed so far

        # Initialize the CPE ML Plugin with one team (single thread for now) and the model size
        # (reduce is the Python 2 built-in; under Python 3 it comes from functools)
        totsize = sum([reduce(lambda x, y: x * y, v.get_shape().as_list()) for v in tf.trainable_variables()])
        mc.init(1, 1, totsize, "tensorflow")
        hp.RUNPARAM['batch_per_epoch'] = hp.RUNPARAM['batch_per_epoch'] // mc.get_nranks()
        hp.RUNPARAM['batch_per_epoch_val'] = hp.RUNPARAM['batch_per_epoch_val'] // mc.get_nranks()
        totsteps = hp.RUNPARAM['num_epoch'] * hp.RUNPARAM['batch_per_epoch']
        mc.config_team(0, 0, totsteps, totsteps, 2, 50)

        if mc.get_rank() == 0:
            print("+------------------------------+")
            print("| CosmoFlow                    |")
            print("| # Ranks = {:5d}              |".format(mc.get_nranks()))
            print("| Global Batch = {:6d}        |".format(mc.get_nranks() * hp.Input['BATCH_SIZE']))
            print("| # Parameters = {:9d}     |".format(totsize))
            print("+------------------------------+")

        # Use the CPE ML Plugin to broadcast initial model parameter values from rank 0
        new_vars = mc.broadcast(tf.trainable_variables(), 0)
        bcast = tf.group(*[tf.assign(v, new_vars[k]) for k, v in enumerate(tf.trainable_variables())])

        if self.is_train:
            with tf.Session(config=config) as sess:
                losses_train = []
                losses_val = []
                losses = []
                val_accuracys = []
                data_accuracys = []

                # Do all parameter initializations, then apply the broadcast op
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                sess.run(bcast)

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)

                elapsed_time = 0.
                for epoch in range(hp.RUNPARAM['num_epoch']):
                    save_path = os.path.join(hp.Path['Model_path'], 'best_validation')
                    total_iterations += 1
                    start_time = time.time()
                    loss_per_epoch_val = 0
                    loss_per_epoch_train = 0
                    for i in range(hp.RUNPARAM['batch_per_epoch']):
                        step_start_time = time.time()
                        _, lossTrain, lossL1Train_, train_true_, train_predict_ = sess.run(
                            [train_step, loss, lossL1Train, train_true, train_predict])
                        step_finish_time = time.time()

                        elapsed_time += (step_finish_time - step_start_time)
                        samps_per_sec = mc.get_nranks() * (epoch * hp.RUNPARAM['batch_per_epoch'] * hp.Input['BATCH_SIZE'] + (i + 1) * hp.Input['BATCH_SIZE']) / elapsed_time
                        if mc.get_rank() == 0:
                            print("Train Step: " + str(i) + ", Samples/Sec = " + str(samps_per_sec) + ", Loss = " + str(lossTrain))

                        loss_per_epoch_train += lossL1Train_

                    # Average the per-rank training loss across all ranks
                    global_loss = np.array([loss_per_epoch_train], dtype=np.float32)
                    mc.average(global_loss)
                    loss_per_epoch_train = global_loss / hp.RUNPARAM['batch_per_epoch']
                    losses.append(loss_per_epoch_train)
                    losses_train.append(loss_per_epoch_train)

                    for i in range(hp.RUNPARAM['batch_per_epoch_val']):
                        if mc.get_rank() == 0:
                            print("Val Step = " + str(i))
                        loss_, val_true_, val_predict_ = sess.run([lossL1Val, val_true, val_predict])
                        loss_per_epoch_val += loss_

                    # Average the per-rank validation loss across all ranks
                    global_loss = np.array([loss_per_epoch_val], dtype=np.float32)
                    mc.average(global_loss)
                    loss_per_epoch_val = global_loss / hp.RUNPARAM['batch_per_epoch_val']
                    losses_val.append(loss_per_epoch_val)

                    # Save a checkpoint whenever the validation loss improves
                    if loss_per_epoch_val < best_validation_accuracy:
                        best_validation_accuracy = loss_per_epoch_val
                        last_improvement = total_iterations
                        if mc.get_rank() == 0:
                            saver.save(sess=sess, save_path=save_path)

                    if mc.get_rank() == 0:
                        print("Epoch {} took {:.3f}s".format(epoch, time.time() - start_time))
                        print("  training loss: %.3f" % loss_per_epoch_train)
                        print("  validation loss: %.3f" % loss_per_epoch_val)
                        print("  best loss: %.3f" % best_validation_accuracy)
                        np.savetxt(os.path.join(hp.Path['train_result'], 'loss_train.txt'), losses_train)
                        np.savetxt(os.path.join(hp.Path['val_result'], 'loss_val.txt'), losses_val)
                        np.savetxt(os.path.join(hp.Path['train_result'], 'losses.txt'), losses)
                        #np.savetxt(os.path.join(hp.Path['train_result'], 'train_pred' + str(epoch) + '.txt'), np.c_[train_true_, train_predict_])
                        #np.savetxt(os.path.join(hp.Path['val_result'], 'val_pred' + str(epoch) + '.txt'), np.c_[val_true_, val_predict_])

                    # Early stopping: quit if the validation loss has not improved recently
                    if total_iterations - last_improvement > require_improvement:
                        if mc.get_rank() == 0:
                            print("No improvement found in a while, stopping optimization.")
                        break

                coord.request_stop()
                coord.join(threads)

        if self.is_test and mc.get_rank() == 0:

            save_path = os.path.join(hp.Path['Model_path'], 'best_validation')
            if self.save_path is not None:
                save_path = self.save_path

            with tf.Session() as sess:
                saver.restore(sess=sess, save_path=save_path)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
                loss_test = []
                for i in range(0, hp.RUNPARAM['iter_test']):
                    start_time = time.time()
                    lossL1Test_, test_true_, test_predict_ = sess.run([lossL1Test, test_true, test_predict])
                    loss_test.append(lossL1Test_)
                    print("Box {} took {:.3f}s".format(i, time.time() - start_time))
                    print("  test loss: %.3f" % lossL1Test_)
                    np.savetxt(os.path.join(hp.Path['test_result'], 'test_batch_' + str(i) + '.txt'), np.c_[test_true_, test_predict_])
                np.savetxt(os.path.join(hp.Path['test_result'], 'loss_test.txt'), loss_test)
                coord.request_stop()
                coord.join(threads)

        # Clean up the CPE ML Plugin
        mc.finalize()
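
All three examples follow the same plugin lifecycle: `mc.init` sets up the plugin (here with one team, one thread, and the model size), `mc.broadcast` makes rank 0's parameters authoritative on every rank, `mc.average` reduces per-rank metrics, and `mc.finalize` shuts the plugin down. The following is a minimal standalone sketch of the metric-averaging step used for the per-epoch losses above; it reuses only calls that appear in the examples, and the 1-element buffer size passed to `mc.init` is illustrative rather than taken from the original.

import numpy as np
import ml_comm as mc  # Cray CPE ML Plugin; import alias assumed

# One team, one thread, room for a single float, TensorFlow backend
# (mirrors the mc.init(1, 1, totsize, "tensorflow") call above).
mc.init(1, 1, 1, "tensorflow")

# Each rank fills the buffer with its local value; mc.average overwrites it
# in place with the mean over all ranks, as for loss_per_epoch_train above.
local_metric = np.array([float(mc.get_rank())], dtype=np.float32)
mc.average(local_metric)

if mc.get_rank() == 0:
    print("mean over {} ranks = {}".format(mc.get_nranks(), local_metric[0]))

mc.finalize()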