Example 1
    def train(self, train_loader, val_loader=None):
        ''' Given data loaders, train the network '''
        # Parameter directory
        save_dir = cfg.DIR.OUT_PATH
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Timer for the training op and parallel data loading op.
        train_timer = Timer()
        data_timer = Timer()
        training_losses = []

        # Setup learning rates
        lr_steps = [int(k) for k in cfg.TRAIN.LEARNING_RATES.keys()]

        #Setup the lr_scheduler
        self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer,
                                                     lr_steps,
                                                     gamma=0.1)

        start_iter = 0
        # Resume training
        if cfg.TRAIN.RESUME_TRAIN:
            self.load(cfg.CONST.WEIGHTS)
            start_iter = cfg.TRAIN.INITIAL_ITERATION

        # Main training loop
        train_loader_iter = iter(train_loader)
        for train_ind in range(start_iter, cfg.TRAIN.NUM_ITERATION + 1):
            # Advance the LR schedule; note that recent PyTorch expects
            # scheduler.step() to be called after optimizer.step()
            self.lr_scheduler.step()

            data_timer.tic()
            try:
                batch_img, batch_voxel = next(train_loader_iter)
            except StopIteration:
                # Restart the iterator once the loader is exhausted
                train_loader_iter = iter(train_loader)
                batch_img, batch_voxel = next(train_loader_iter)

            data_timer.toc()

            # Networks that take a single 4-D tensor use only the first element
            if self.net.is_x_tensor4:
                batch_img = batch_img[0]

            # Apply one gradient step
            train_timer.tic()
            loss = self.train_loss(batch_img, batch_voxel)
            train_timer.toc()

            training_losses.append(loss.item())

            # Decrease learning rate at certain points
            if train_ind in lr_steps:
                # For a PyTorch optimizer, the learning rate can only be set when
                # the optimizer is created, or through torch.optim.lr_scheduler
                print('Learning rate decreased to %f' %
                      cfg.TRAIN.LEARNING_RATES[str(train_ind)])

            # Debugging modules
            #
            # Print status, run validation, check divergence, and save model.
            if train_ind % cfg.TRAIN.PRINT_FREQ == 0:
                # Print the current loss
                print('%s Iter: %d Loss: %f' %
                      (datetime.now(), train_ind, loss.item()))

            if train_ind % cfg.TRAIN.VALIDATION_FREQ == 0 and val_loader is not None:
                # Print test loss and params to check convergence every N iterations

                val_losses = 0
                val_num_iter = min(cfg.TRAIN.NUM_VALIDATION_ITERATIONS,
                                   len(val_loader))
                val_loader_iter = iter(val_loader)
                for i in range(val_num_iter):
                    batch_img, batch_voxel = next(val_loader_iter)
                    val_loss = self.train_loss(batch_img, batch_voxel)
                    val_losses += val_loss.item()
                val_losses_mean = val_losses / val_num_iter
                print('%s Test loss: %f' % (datetime.now(), val_losses_mean))

            if train_ind % cfg.TRAIN.NAN_CHECK_FREQ == 0:
                # Check that the network parameters are all valid
                nan_or_max_param = max_or_nan(self.net.parameters())
                if has_nan(nan_or_max_param):
                    print('NAN detected')
                    break

            if (train_ind % cfg.TRAIN.SAVE_FREQ == 0 and train_ind != 0) or \
                    (train_ind == cfg.TRAIN.NUM_ITERATION):
                # Save the checkpoint every few iterations and at the end
                self.save(training_losses, save_dir, train_ind)

            # loss is a scalar tensor
            if loss.item() > cfg.TRAIN.LOSS_LIMIT:
                print("Cost exceeds the threshold. Stop training")
                break
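
Both examples lean on the fact noted in the comments above: with a PyTorch optimizer, the learning rate is fixed when the optimizer is created unless a torch.optim.lr_scheduler adjusts it. Below is a minimal, self-contained sketch of the MultiStepLR behaviour the training loop relies on; the linear model, base rate of 0.1, and milestones [3, 6] are illustrative stand-ins for the values in cfg.TRAIN.LEARNING_RATES.

import torch
from torch.optim import SGD, lr_scheduler

# Hypothetical model and milestones standing in for cfg.TRAIN.LEARNING_RATES
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for step in range(8):
    optimizer.step()        # the real gradient step happens inside train_loss()
    scheduler.step()        # multiplies the lr by gamma at each milestone
    print(step, optimizer.param_groups[0]['lr'])  # 0.1 -> 0.01 -> 0.001

Because gamma=0.1 is fixed at construction, the per-milestone values stored in cfg.TRAIN.LEARNING_RATES are only used for the printout in the loop; the scheduler itself always decays the rate by a factor of 10.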
Example 2
    def train(self, train_queue, val_queue=None):
        ''' Given data queues, train the network '''
        # Parameter directory
        save_dir = cfg.DIR.OUT_PATH
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Timer for the training op and parallel data loading op.
        train_timer = Timer()
        data_timer = Timer()
        training_losses = []

        # Setup learning rates
        lr_steps = [int(k) for k in cfg.TRAIN.LEARNING_RATES.keys()]

        #Setup the lr_scheduler
        self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer,
                                                     lr_steps,
                                                     gamma=0.1)  # gamma is the decay factor

        start_iter = 0
        # Resume training
        if cfg.TRAIN.RESUME_TRAIN:
            self.load(cfg.CONST.WEIGHTS)
            start_iter = cfg.TRAIN.INITIAL_ITERATION

        if cfg.TRAIN.SHOW_LOSS:  # live loss plotting
            import matplotlib.pyplot as plot
            plot.figure(1, figsize=(12, 5))
            plot.ion()

        # Main training loop
        for train_ind in range(start_iter, cfg.TRAIN.NUM_ITERATION + 1):
            # Advance the LR schedule; note that recent PyTorch expects
            # scheduler.step() to be called after optimizer.step()
            self.lr_scheduler.step()

            data_timer.tic()
            batch_img, batch_voxel = train_queue.get()
            data_timer.toc()

            # Networks that take a single 4-D tensor use only the first element
            if self.net.is_x_tensor4:
                batch_img = batch_img[0]

            # Apply one gradient step
            train_timer.tic()
            loss = self.train_loss(batch_img, batch_voxel)
            train_timer.toc()

            # Store the loss as a numpy value (move to CPU first when on GPU)
            if torch.cuda.is_available():
                training_losses.append(loss.cpu().data.numpy())
            else:
                training_losses.append(loss.data.numpy())

            # Decrease learning rate at certain points
            if train_ind in lr_steps:
                # For a PyTorch optimizer, the learning rate can only be set when
                # the optimizer is created, or through torch.optim.lr_scheduler
                print('Learning rate decreased to %f' %
                      cfg.TRAIN.LEARNING_RATES[str(train_ind)])

            # Debugging modules
            #
            # Print status, run validation, check divergence, and save model.
            if train_ind % cfg.TRAIN.PRINT_FREQ == 0:  # e.g. PRINT_FREQ = 40
                # Print the current loss
                print('%s Iter: %d Loss: %f' %
                      (datetime.now(), train_ind, loss.item()))
                # TODO(dingyadong): dynamic loss visualization
                if train_ind != 0 and cfg.TRAIN.SHOW_LOSS:  # live loss plotting
                    # x axis aligned with the iterations actually recorded
                    steps = np.arange(start_iter, train_ind + 1, dtype=np.float32)
                    plot.plot(steps, training_losses, 'b-')
                    plot.draw()
                    # plot.pause(0.05)

            if train_ind % cfg.TRAIN.VALIDATION_FREQ == 0 and val_queue is not None:
                # Print test loss and params to check convergence every N iterations

                val_losses = 0
                for i in range(cfg.TRAIN.NUM_VALIDATION_ITERATIONS):
                    batch_img, batch_voxel = val_queue.get()
                    val_loss = self.train_loss(batch_img, batch_voxel)
                    val_losses += val_loss.item()
                val_losses_mean = val_losses / cfg.TRAIN.NUM_VALIDATION_ITERATIONS
                print('%s Test loss: %f' % (datetime.now(), val_losses_mean))

            if train_ind % cfg.TRAIN.NAN_CHECK_FREQ == 0:
                # Check that the network parameters are all valid
                nan_or_max_param = max_or_nan(self.net.parameters())
                if has_nan(nan_or_max_param):
                    print('NAN detected')
                    break

            if train_ind % cfg.TRAIN.SAVE_FREQ == 0 and train_ind != 0:
                # Save the checkpoint every few iterations
                self.save(training_losses, save_dir, train_ind)

            # loss is a scalar tensor
            if loss.item() > cfg.TRAIN.LOSS_LIMIT:
                print("Cost exceeds the threshold. Stop training")
                break

        if cfg.TRAIN.SHOW_LOSS:  # live loss plotting
            plot.ioff()
            plot.show()
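
Example 2's live loss curve uses matplotlib's interactive mode: ion() before the loop so that draw() returns without blocking, then ioff() plus show() afterwards to keep the final figure open. Here is a stand-alone sketch of that pattern, with a synthetic decaying loss standing in for the real training loss.

import numpy as np
import matplotlib.pyplot as plt

plt.figure(1, figsize=(12, 5))
plt.ion()                   # interactive mode: draw() does not block

losses = []
for step in range(50):
    # Synthetic loss in place of self.train_loss(batch_img, batch_voxel)
    losses.append(np.exp(-step / 20.0) + 0.05 * np.random.rand())
    if step % 10 == 0 and step != 0:
        plt.plot(np.arange(len(losses)), losses, 'b-')
        plt.draw()
        plt.pause(0.05)     # give the GUI event loop time to repaint

plt.ioff()                  # back to blocking mode
plt.show()                  # keep the final figure on screen

The plot.pause(0.05) call that Example 2 leaves commented out is usually needed in practice; without it, many backends never get a chance to repaint between draw() calls.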