Code example #1
File: trainer.py  Project: dsp6414/pytlib
    def train(self):
        # run a forward pass first so dynamic models are fully built before
        # the checkpoint load below
        batched_data, _, _ = load_samples(self.model.get_loader(),
                                          self.model.cuda,
                                          self.args.batch_size)
        self.evaluate_model(batched_data)
        self.iteration = load(self.args.output_dir, self.model.get_model(),
                              self.iteration, self.model.get_optimizer())

        for i in range(self.iteration, self.iteration + self.args.iterations):
            #################### LOAD INPUTS ############################
            # TODO: make a separate timer class if more complex timings arise
            t0 = time.time()
            batched_data, batched_targets, sample_array = load_samples(
                self.model.get_loader(), self.model.cuda, self.args.batch_size)
            self.logger.set('timing.input_loading_time', time.time() - t0)
            #############################################################

            #################### FORWARD ################################
            t1 = time.time()
            outputs = self.evaluate_model(batched_data)
            self.logger.set('timing.foward_pass_time', time.time() - t1)
            #############################################################

            #################### BACKWARD AND SGD  #####################
            t2 = time.time()
            # the loss function receives the model outputs followed by the targets
            loss = self.model.get_lossfn()(*(outputs + batched_targets))
            self.model.get_optimizer().zero_grad()
            loss.backward()
            self.model.get_optimizer().step()
            self.logger.set('timing.loss_backward_update_time', time.time() - t2)
            #############################################################

            #################### LOGGING, VIZ and SAVE ###################
            print('iteration: {0} loss: {1}'.format(self.iteration, loss.data[0]))

            if self.args.compute_graph and i == self.iteration:
                compute_graph(loss, output_file=os.path.join(self.args.output_dir,
                                                             self.args.compute_graph))

            if self.iteration % self.args.save_iter == 0:
                save(self.model.get_model(), self.model.get_optimizer(),
                     self.iteration, self.args.output_dir)

            self.logger.set('time', time.time())
            self.logger.set('date', str(datetime.now()))
            self.logger.set('loss', loss.data[0])
            self.logger.set('iteration', self.iteration)
            self.logger.dump_line()
            self.iteration += 1

            if self.args.visualize_iter > 0 and self.iteration % self.args.visualize_iter == 0:
                Batcher.debatch_outputs(sample_array, outputs)
                list(map(lambda x: x.visualize({'title': random_str(5)}), sample_array))
                ImageVisualizer().dump_image(
                    os.path.join(self.args.output_dir,
                                 'visualizations_{0:08d}.svg'.format(self.iteration)))
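The training loop above touches the model only through a small accessor interface (get_loader, get_model, get_optimizer, get_lossfn, plus a cuda flag) and invokes the loss as get_lossfn()(*(outputs + batched_targets)), that is, outputs first, then targets. Below is a minimal sketch of a wrapper that would satisfy that interface; everything in it other than those accessor names is an assumption for illustration and is not taken from the pytlib project.

# Minimal sketch (assumed, not from the project) of a model wrapper exposing
# the accessors the training loop above relies on.
import torch.nn as nn
import torch.optim as optim


class ExampleModelWrapper(object):
    def __init__(self, loader, cuda=False, lr=0.01):
        self.cuda = cuda                  # read by load_samples() in the loop above
        self._loader = loader             # any object load_samples() can draw from
        self._net = nn.Sequential(        # stand-in network for illustration only
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10))
        if cuda:
            self._net = self._net.cuda()
        self._optimizer = optim.SGD(self._net.parameters(), lr=lr)
        self._lossfn = nn.CrossEntropyLoss()

    def get_loader(self):
        return self._loader

    def get_model(self):
        return self._net

    def get_optimizer(self):
        return self._optimizer

    def get_lossfn(self):
        # train() calls this as get_lossfn()(*(outputs + batched_targets)),
        # so the returned callable must accept outputs first, then targets.
        return self._lossfn

Any callable works as the loss here, as long as it accepts the model outputs followed by the batched targets positionally.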
Code example #2
    def test(self):
        # load after a forward call for dynamic models
        batched_data, _, _ = load_samples(self.model.get_loader(),
                                          self.model.cuda,
                                          self.args.batch_size)
        self.evaluate_model(batched_data)
        self.iteration = load(self.args.output_dir, self.model.get_model(),
                              self.iteration)

        for i in range(self.iteration, self.iteration + self.args.iterations):
            #################### LOAD INPUTS ############################
            t0 = time.time()
            batched_data, batched_targets, sample_array = load_samples(
                self.model.get_loader(), self.model.cuda, self.args.batch_size)
            self.logger.set('timing.input_loading_time', time.time() - t0)
            #############################################################

            #################### FORWARD ################################
            t1 = time.time()
            outputs = self.evaluate_model(batched_data)
            self.logger.set('timing.foward_pass_time', time.time() - t1)
            #############################################################

            #################### LOGGING, VIZ ###################
            print('iteration: {0}'.format(self.iteration))

            self.logger.set('time', time.time())
            self.logger.set('date', str(datetime.now()))
            self.logger.set('iteration', self.iteration)
            self.logger.dump_line()
            self.iteration += 1

            Batcher.debatch_outputs(sample_array, outputs)
            list(
                map(
                    lambda x: x.visualize({
                        'title': random_str(5),
                        'mode': 'test'
                    }), sample_array))
            if self.args.visualize_iter > 0 and self.iteration % self.args.visualize_iter == 0:
                print('dumping {}'.format('testviz_{0:08d}.svg'.format(
                    self.iteration)))
                ImageVisualizer().dump_image(
                    os.path.join(self.args.output_dir,
                                 'testviz_{0:08d}.svg'.format(self.iteration)))
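The test loop mirrors the training loop but drops the backward/SGD step and restores weights via load() without an optimizer. Both loops read output_dir, batch_size, iterations, and visualize_iter off self.args; train() additionally reads save_iter and compute_graph. The sketch below is a hypothetical driver built around those field names; the defaults and the commented-out Trainer construction are assumptions for illustration, not the project's actual entry point.

# Hypothetical driver: the argument names mirror the self.args fields read by
# train() and test() above; defaults and the Trainer construction are assumed.
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', required=True,
                        help='where checkpoints and visualization SVGs are written')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--iterations', type=int, default=1000,
                        help='iterations to run from the loaded checkpoint')
    parser.add_argument('--save_iter', type=int, default=100,
                        help='checkpoint every N iterations (train only)')
    parser.add_argument('--visualize_iter', type=int, default=0,
                        help='dump an SVG every N iterations; 0 disables the dump')
    parser.add_argument('--compute_graph', default='',
                        help='filename for a one-time compute-graph dump (train only)')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # trainer = Trainer(model=ExampleModelWrapper(...), args=args)  # assumed constructor
    # trainer.train()  # or trainer.test() for the evaluation loop in code example #2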