Example #1
    def get_lr_scheduler(self):
        if args.lr_decay_period > 0:
            lr_decay_epoch = list(
                range(args.lr_decay_period, args.epochs, args.lr_decay_period))
        else:
            lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
        lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
        num_batches = args.num_samples // args.batch_size
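        # two-stage schedule: linear warmup from 0 to args.lr over the warmup epochs,
        # then args.lr_mode decay (step_epoch/step_factor for 'step', power for 'poly')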
        lr_scheduler = gutils.LRSequential([
            gutils.LRScheduler('linear',
                               base_lr=0,
                               target_lr=args.lr,
                               nepochs=args.warmup_epochs,
                               iters_per_epoch=num_batches),
            gutils.LRScheduler(args.lr_mode,
                               base_lr=args.lr,
                               nepochs=args.epochs - args.warmup_epochs,
                               iters_per_epoch=num_batches,
                               step_epoch=lr_decay_epoch,
                               step_factor=args.lr_decay,
                               power=2),
        ])
        return lr_scheduler
Example #2
    def train(self):
        if args.lr_decay_period > 0:
            lr_decay_epoch = list(
                range(args.lr_decay_period, args.epochs, args.lr_decay_period))
        else:
            lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
        lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
        num_batches = args.num_samples // args.batch_size
        lr_scheduler = gutils.LRSequential([
            gutils.LRScheduler('linear',
                               base_lr=0,
                               target_lr=args.lr,
                               nepochs=args.warmup_epochs,
                               iters_per_epoch=num_batches),
            gutils.LRScheduler(args.lr_mode,
                               base_lr=args.lr,
                               nepochs=args.epochs - args.warmup_epochs,
                               iters_per_epoch=num_batches,
                               step_epoch=lr_decay_epoch,
                               step_factor=args.lr_decay,
                               power=2),
        ])

        # drive SGD with the composed warmup+decay schedule; wd and momentum come from the CLI args
        trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })

        # set up logger
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)

        logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
        best_loss, best_map = 1000, 0
        for epoch in range(args.start_epoch, args.epochs):
            tic = time.time()
            btic = time.time()
            self.net.hybridize()
            for i, batch in enumerate(self.train_dataloader):
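                # split each batch component across the available contexts (one slice per GPU/CPU)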
                data = gluon.utils.split_and_load(batch[0], ctx_list=self.ctx)
                cls_targ = gluon.utils.split_and_load(batch[1],
                                                      ctx_list=self.ctx)
                box_targ = gluon.utils.split_and_load(batch[2],
                                                      ctx_list=self.ctx)
                box_mask = gluon.utils.split_and_load(batch[3],
                                                      ctx_list=self.ctx)
                seg_gt = gluon.utils.split_and_load(batch[4],
                                                    ctx_list=self.ctx)
                mask = gluon.utils.split_and_load(batch[5], ctx_list=self.ctx)
                sum_losses, cls_losses, box_losses, seg_losses = [], [], [], []
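                # forward pass per device under autograd.record so every prediction head contributes gradients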
                with mx.autograd.record():
                    for d, ct, bt, bm, sg, m in zip(data, cls_targ, box_targ,
                                                    box_mask, seg_gt, mask):
                        cls_pred, box_pred, _, seg_pred = self.net(d)
                        pred = {
                            'cls_pred': cls_pred,
                            'box_pred': box_pred,
                            'seg_pred': seg_pred
                        }
                        lab = {
                            'cls_targ': ct,
                            'box_targ': bt,
                            'box_mask': bm,
                            'seg_gt': sg,
                            'mask': m
                        }
                        loss, metrics = self.loss(pred, lab)
                        sum_losses.append(loss)
                        cls_losses.append(metrics['cls_loss'])
                        box_losses.append(metrics['box_loss'])
                        seg_losses.append(metrics['seg_loss'])
                    mx.autograd.backward(sum_losses)
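                # step(1): the loss module presumably already averages over the batch, so no extra rescaling is applied here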
                trainer.step(1)
                self.sum_loss.update(0, sum_losses)
                self.cls_loss.update(0, cls_losses)
                self.box_loss.update(0, box_losses)
                self.seg_loss.update(0, seg_losses)

                if args.log_interval and not (i + 1) % args.log_interval:
                    name0, loss0 = self.sum_loss.get()
                    name1, loss1 = self.cls_loss.get()
                    name2, loss2 = self.box_loss.get()
                    name3, loss3 = self.seg_loss.get()
                    logger.info(
                        '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                        .format(epoch, i + 1, trainer.learning_rate,
                                args.batch_size / (time.time() - btic), name0,
                                loss0, name1, loss1, name2, loss2, name3,
                                loss3))
                btic = time.time()
            name0, loss0 = self.sum_loss.get()
            name1, loss1 = self.cls_loss.get()
            name2, loss2 = self.box_loss.get()
            name3, loss3 = self.seg_loss.get()
            logger.info(
                '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'
                .format(epoch,
                        time.time() - tic, name0, loss0, name1, loss1, name2,
                        loss2, name3, loss3))

            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                mean_ap, seg_loss = self.validate(epoch, logger)
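                # keep the "best" checkpoint only when detection mAP improves and segmentation loss drops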
                if mean_ap[-1] > best_map and seg_loss < best_loss:
                    best_map, best_loss = mean_ap[-1], seg_loss
                    self.net.save_parameters('{:s}_best.params'.format(
                        args.save_prefix))
                if args.save_interval and (epoch +
                                           1) % args.save_interval == 0:
                    self.net.save_parameters(
                        '{:s}_{:04d}_{:.3f}.params'.format(
                            args.save_prefix, epoch + 1, best_map))
Example #3
    def train(self):
        if args.lr_decay_period > 0:
            lr_decay_epoch = list(
                range(args.lr_decay_period, args.epochs, args.lr_decay_period))
        else:
            lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
        lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
        num_batches = args.num_samples // args.batch_size
        lr_scheduler = gutils.LRSequential([
            gutils.LRScheduler('linear',
                               base_lr=0,
                               target_lr=args.lr,
                               nepochs=args.warmup_epochs,
                               iters_per_epoch=num_batches),
            gutils.LRScheduler(args.lr_mode,
                               base_lr=args.lr,
                               nepochs=args.epochs - args.warmup_epochs,
                               iters_per_epoch=num_batches,
                               step_epoch=lr_decay_epoch,
                               step_factor=args.lr_decay,
                               power=2),
        ])

        trainer = gluon.Trainer(self.net.collect_params(), 'sgd', {
            'wd': args.wd,
            'momentum': args.momentum,
            'lr_scheduler': lr_scheduler
        })

        # set up logger
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)

        logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
        best_acc = [0]

        for epoch in range(args.start_epoch, args.epochs):
            tic = time.time()
            btic = time.time()
            self.net.hybridize()

            for i, batch in enumerate(self.train_dataloader):
                src_data = gluon.utils.split_and_load(batch[0],
                                                      ctx_list=self.ctx)
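                # subsample the mask (every 32nd row, every 8th column) before splitting, presumably to match the model's output grid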
                mask = batch[1][:, :, ::32, ::8]
                src_mask = gluon.utils.split_and_load(mask, ctx_list=self.ctx)
                src_targ = gluon.utils.split_and_load(batch[2],
                                                      ctx_list=self.ctx)
                tag_lab = gluon.utils.split_and_load(batch[3],
                                                     ctx_list=self.ctx)
                tag_mask = gluon.utils.split_and_load(batch[4],
                                                      ctx_list=self.ctx)
                l_list = []
                with mx.autograd.record():
                    for sd, sm, st, tl, tm in zip(src_data, src_mask, src_targ,
                                                  tag_lab, tag_mask):
                        bs = sd.shape[0]  # per-device batch size (assumed from the data slice)
                        states = self.net.begin_state(bs, sd.context)
                        outputs = self.net(sd, sm, st)
                        loss = self.loss(outputs, tl, tm.expand_dims(axis=2))
                        l_list.append(loss)
                    mx.autograd.backward(l_list)
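                # step(batch_size) normalizes the accumulated gradients by 1/batch_size before the update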
                trainer.step(args.batch_size)
                mx.nd.waitall()
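                # accuracy is measured on the last device's outputs only; the loss metric aggregates all devices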
                self.acc_metric.update(outputs, tl, tm)
                self.loss_metric.update(0, l_list)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, acc1 = self.acc_metric.get()
                    name2, loss2 = self.loss_metric.get()
                    logger.info(
                        '[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'
                        .format(epoch, i, trainer.learning_rate,
                                args.batch_size / (time.time() - btic), name1,
                                acc1, name2, loss2))
                btic = time.time()

            name1, acc1 = self.acc_metric.get()
            name2, loss2 = self.loss_metric.get()
            logger.info(
                '[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.
                format(epoch, (time.time() - tic), name1, acc1, name2, loss2))
            if not epoch % args.val_interval:
                name, current_acc = self.evaluate()
                logger.info('[Epoch {}] Validation: {}={:.3f}'.format(
                    epoch, name, current_acc))

                # checkpointing happens right after validation, while current_acc is fresh
                if current_acc > best_acc[0]:
                    best_acc[0] = current_acc
                    self.net.save_parameters(
                        '{:s}_best.params'.format(args.save_prefix))
                    with open(args.save_prefix + '_best_map.log', 'a') as f:
                        f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_acc))
                if args.save_interval and epoch % args.save_interval == 0:
                    self.net.save_parameters(
                        '{:s}_{:04d}_{:.4f}.params'.format(
                            args.save_prefix, epoch, current_acc))
            self.acc_metric.reset()
            self.loss_metric.reset()