def train(self):
    """Run ``self.args.iterations`` training steps, resuming from a checkpoint.

    A warm-up forward pass is run first. Then model and optimizer state are
    restored from ``self.args.output_dir`` via ``load``, and
    ``self.iteration`` is set to the value ``load`` returns. Each step then:

    * loads a batch (timed as ``timing.input_loading_time``),
    * runs the forward pass (timed as ``timing.foward_pass_time``),
    * computes the loss with the model's loss function, back-propagates and
      steps the optimizer (timed as ``timing.loss_backward_update_time``),
    * prints and logs the loss,
    * on the first step of this run only, and only when
      ``self.args.compute_graph`` is set, writes the compute graph,
    * saves a checkpoint whenever ``self.iteration % self.args.save_iter == 0``,
    * when ``self.args.visualize_iter > 0``, visualizes the batch every
      ``visualize_iter`` iterations and dumps an SVG into
      ``self.args.output_dir``.
    """
    # Load after a forward call for dynamic models (presumably so that
    # lazily-created parameters exist before the checkpoint is restored).
    batched_data, _, _ = load_samples(self.model.get_loader(), self.model.cuda,
                                      self.args.batch_size)
    self.evaluate_model(batched_data)
    self.iteration = load(self.args.output_dir, self.model.get_model(),
                          self.iteration, self.model.get_optimizer())

    # self.iteration is incremented in lock-step with ``i`` below, so
    # ``i == self.iteration`` would hold on every step. Compare against the
    # starting iteration to detect the first step of this run.
    start_iteration = self.iteration
    for i in range(start_iteration, start_iteration + self.args.iterations):
        #################### LOAD INPUTS ############################
        # TODO, make separate timer class if more complex timings arise
        t0 = time.time()
        batched_data, batched_targets, sample_array = load_samples(
            self.model.get_loader(), self.model.cuda, self.args.batch_size)
        self.logger.set('timing.input_loading_time', time.time() - t0)

        #################### FORWARD ################################
        t1 = time.time()
        outputs = self.evaluate_model(batched_data)
        # NOTE: the key is misspelled ("foward") but kept as-is so existing
        # log consumers keep working.
        self.logger.set('timing.foward_pass_time', time.time() - t1)

        #################### BACKWARD AND SGD #######################
        t2 = time.time()
        loss = self.model.get_lossfn()(*(outputs + batched_targets))
        self.model.get_optimizer().zero_grad()
        loss.backward()
        self.model.get_optimizer().step()
        self.logger.set('timing.loss_backward_update_time', time.time() - t2)

        #################### LOGGING, VIZ and SAVE ##################
        # ``loss.backward()`` with no argument implies a scalar loss. Recent
        # PyTorch returns it as a 0-dim tensor, where ``loss.data[0]`` raises,
        # so prefer ``item()`` and fall back for old versions without it.
        loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
        print('iteration: {0} loss: {1}'.format(self.iteration, loss_value))
        if self.args.compute_graph and i == start_iteration:
            compute_graph(loss, output_file=os.path.join(
                self.args.output_dir, self.args.compute_graph))
        if self.iteration % self.args.save_iter == 0:
            save(self.model.get_model(), self.model.get_optimizer(),
                 self.iteration, self.args.output_dir)
        self.logger.set('time', time.time())
        self.logger.set('date', str(datetime.now()))
        self.logger.set('loss', loss_value)
        self.logger.set('iteration', self.iteration)
        self.logger.dump_line()
        self.iteration += 1

        if (self.args.visualize_iter > 0
                and self.iteration % self.args.visualize_iter == 0):
            Batcher.debatch_outputs(sample_array, outputs)
            # Explicit loop: ``map`` is lazy on Python 3 and would never
            # actually call ``visualize``.
            for sample in sample_array:
                sample.visualize({'title': random_str(5)})
            ImageVisualizer().dump_image(os.path.join(
                self.args.output_dir,
                'visualizations_{0:08d}.svg'.format(self.iteration)))
def test(self):
    """Run ``self.args.iterations`` evaluation steps on restored weights.

    A warm-up forward pass is run first. Then model weights are restored
    from ``self.args.output_dir`` via ``load``; no optimizer is passed.
    Each step does the following:

    * loads a batch and runs the forward pass, logging both timings,
    * prints and logs the iteration,
    * debatches the outputs and visualizes every sample in ``'test'`` mode.

    Whenever ``self.args.visualize_iter > 0`` divides the post-increment
    iteration, the step also dumps ``testviz_<iteration>.svg`` through
    ``ImageVisualizer``. Presumably that file holds the visualizations made
    since the previous dump (TODO confirm).
    """
    # Load after a forward call for dynamic models.
    warmup_batch, _, _ = load_samples(self.model.get_loader(), self.model.cuda,
                                      self.args.batch_size)
    self.evaluate_model(warmup_batch)
    self.iteration = load(self.args.output_dir, self.model.get_model(),
                          self.iteration)

    first = self.iteration
    for _ in range(first, first + self.args.iterations):
        # -- inputs ------------------------------------------------------
        load_start = time.time()
        batch, _targets, samples = load_samples(
            self.model.get_loader(), self.model.cuda, self.args.batch_size)
        self.logger.set('timing.input_loading_time', time.time() - load_start)

        # -- forward -----------------------------------------------------
        forward_start = time.time()
        predictions = self.evaluate_model(batch)
        # Key spelling ("foward") kept as-is for log compatibility.
        self.logger.set('timing.foward_pass_time', time.time() - forward_start)

        # -- logging -----------------------------------------------------
        print('iteration: {0}'.format(self.iteration))
        self.logger.set('time', time.time())
        self.logger.set('date', str(datetime.now()))
        self.logger.set('iteration', self.iteration)
        self.logger.dump_line()
        self.iteration += 1

        # -- visualization -----------------------------------------------
        Batcher.debatch_outputs(samples, predictions)
        for sample in samples:
            sample.visualize({'title': random_str(5), 'mode': 'test'})

        should_dump = (self.args.visualize_iter > 0
                       and self.iteration % self.args.visualize_iter == 0)
        if should_dump:
            svg_name = 'testviz_{0:08d}.svg'.format(self.iteration)
            print('dumping {}'.format(svg_name))
            ImageVisualizer().dump_image(
                os.path.join(self.args.output_dir, svg_name))