Beispiel #1
0
    def init_dataset(self, ):
        """Build the training and validation dataloaders.

        ``self.dataloader`` is created from ``self.opts.dataloader``.
        ``self.val_dataloader`` is created from a deep copy of the options
        with ``sample_all_verts`` enabled and ``batch_size`` forced to 2.
        Both loaders are requested with ``shuffle=True``.
        """
        self.dataloader = behave_data_utils.behave_dataoader(
            self.opts.dataloader, shuffle=True)

        # Work on a deep copy so the validation-only tweaks below never
        # alter the options shared with the training loader.
        validation_opts = deepcopy(self.opts)
        loader_cfg = validation_opts.dataloader
        loader_cfg.sample_all_verts = True
        loader_cfg.batch_size = 2
        self.val_dataloader = behave_data_utils.behave_dataoader(
            loader_cfg, shuffle=True)
Beispiel #2
0
    def train(self):
        """Run the main training loop for ``opts.train.num_epochs`` epochs.

        A dataloader is built from ``opts.dataloader``; no ``shuffle``
        argument is passed, so the loader's own default applies. For every
        batch the model runs ``set_input`` -> ``forward`` -> ``backward``.

        Logging cadence (in global steps):
          * every ``opts.logging.log_freq`` steps: print and log training
            scalars;
          * every ``opts.logging.val_log_freq`` steps: build validation
            visuals/scalars and log them.
        A checkpoint is saved whenever the number of completed epochs is a
        multiple of ``opts.logging.save_epoch_freq``.

        Side effects: updates ``self.total_steps``, ``self.epoch_iter`` and
        ``self.epoch``. From epoch index 501 onward,
        ``batch_norm_utils.turnNormOff(self.model)`` is called at the start
        of every epoch.
        """
        opts = self.opts
        dataloader = behave_data_utils.behave_dataoader(opts.dataloader)
        num_epochs = opts.train.num_epochs
        self.total_steps = total_steps = 0
        self.epoch_iter = 0
        for epoch in range(num_epochs):
            self.epoch_iter = 0
            self.epoch = epoch
            # Hard-coded schedule: normalization is switched off for every
            # epoch index above 500 (the call is repeated each epoch).
            # NOTE(review): consider moving 500 into opts.train.
            if epoch > 500:
                batch_norm_utils.turnNormOff(self.model)
            # The batch index was never used, so iterate the loader directly.
            for batch in dataloader:
                model_inputs = self.set_input(batch)
                self.forward(model_inputs)
                self.backward()
                total_steps += 1
                self.total_steps = total_steps
                self.epoch_iter += 1

                if total_steps % opts.logging.log_freq == 0:
                    self.print_scalars()
                    scalars = self.get_current_scalars()
                    self.log_tb_scalars(step=total_steps,
                                        train_scalars=scalars)

                if total_steps % opts.logging.val_log_freq == 0:
                    visuals, val_scalars = self.create_visuals()
                    # Only the first visual is sent to TensorBoard.
                    self.log_to_tensorboard(visuals[0])
                    self.log_tb_scalars(step=self.total_steps,
                                        val_scalars=val_scalars)

            # After this increment self.epoch equals the number of completed
            # epochs; that count is the checkpoint label and save criterion.
            self.epoch += 1
            if self.epoch % opts.logging.save_epoch_freq == 0:
                self.save_network(self.model, "model", self.epoch)
Beispiel #3
0
 def init_dataset(self, ):
     """Create the training dataloader and store it on ``self.dataloader``.

     The loader is built from ``self.opts.dataloader`` with ``shuffle=True``.
     """
     self.dataloader = behave_data_utils.behave_dataoader(
         self.opts.dataloader, shuffle=True)