def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline"""
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    for k, v in net.collect_params('.*bias').items():
        v.wd_mult = 0.0
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
                                        rcnn_box_loss, rcnn_mask_loss)
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup, args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            # losses = []
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get())
                                for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg))
                btic = time.time()

        # validate and save params
        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
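
# The warmup schedule above calls a `get_lr_at_iter` helper that is not defined in this
# file. A minimal sketch of what it is assumed to do: linearly ramp the multiplier from
# `lr_warmup_factor` up to 1.0 as `alpha` goes from 0 to 1. The 1/3 default is an
# assumption for illustration, not taken from this file.
def get_lr_at_iter(alpha, lr_warmup_factor=1. / 3.):
    """Return the learning-rate multiplier at warmup fraction `alpha` (0 <= alpha <= 1)."""
    return lr_warmup_factor * (1 - alpha) + alpha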
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline"""
    args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    for k, v in net.collect_params('.*bias').items():
        v.wd_mult = 0.0
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum}
    if args.clip_gradient > 0.0:
        optimizer_params['clip_gradient'] = args.clip_gradient
    if args.amp:
        optimizer_params['multi_precision'] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rcnn_smoothl1_rho)  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]

    async_eval_processes = []

    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    base_lr = trainer.learning_rate
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
                                        rcnn_box_loss, rcnn_mask_loss, args.amp)
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        speed = []  # per-epoch throughput samples (was used without being initialized)
        train_data_iter = iter(train_data)
        next_data_batch = next(train_data_iter)
        next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
        for i in range(len(train_data)):
            batch = next_data_batch
            if i + epoch * len(train_data) <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter((i + epoch * len(train_data)) / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info('[Epoch {} Iteration {}] Set learning rate to {}'
                                    .format(epoch, i, new_lr))
                    trainer.set_learning_rate(new_lr)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            try:
                # prefetch next batch
                next_data_batch = next(train_data_iter)
                next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
            except StopIteration:
                pass
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get())
                                for metric in metrics + metrics2])
                batch_speed = args.log_interval * args.batch_size / (time.time() - btic)
                speed.append(batch_speed)
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, batch_speed, msg))
                btic = time.time()

        # guard against an empty list when log_interval never triggered
        avg_batch_speed = sum(speed) / len(speed) if speed else 0.

        # validate and save params
        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, Speed: {:.3f} samples/sec, {}'.format(
                epoch, (time.time() - tic), avg_batch_speed, msg))
        if not (epoch + 1) % args.val_interval:
            # consider reducing the frequency of validation to save time
            validate(net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch,
                     best_map, args)
        elif (not args.horovod) or hvd.rank() == 0:
            current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
    for thread in async_eval_processes:
        thread.join()
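
# Both training loops above rely on a `split_and_load` helper that is not shown here. A
# minimal sketch, assuming each field in `batch` is already a list of per-device NDArray
# chunks (one entry per context, as produced by the R-CNN batchify function); it only
# moves each chunk onto its target device:
def split_and_load(batch, ctx_list):
    """Split data so that each device gets one sub-batch of every field."""
    new_batch = []
    for data in batch:
        new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
        new_batch.append(new_data)
    return new_batch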
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    for k, v in net.collect_params('.*beta|.*bias').items():
        v.wd_mult = 0.0

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
            update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        base_lr = trainer.learning_rate
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            batch_size = len(batch[0])
            losses = []
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            with autograd.record():
                for data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(
                        *batch):
                    gt_label = label[:, :, 4:5]
                    gt_box = label[:, :, :4]
                    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, \
                        anchors = net(data, gt_box)
                    # losses of rpn
                    rpn_score = rpn_score.squeeze(axis=-1)
                    num_rpn_pos = (rpn_cls_targets >= 0).sum()
                    rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                             rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
                    rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                             rpn_box_masks) * rpn_box.size / num_rpn_pos
                    # rpn overall loss, use sum rather than average
                    rpn_loss = rpn_loss1 + rpn_loss2
                    # generate targets for rcnn
                    cls_targets, box_targets, box_masks = net.target_generator(roi, samples,
                                                                               matches, gt_label,
                                                                               gt_box)
                    # losses of rcnn
                    num_rcnn_pos = (cls_targets >= 0).sum()
                    rcnn_loss1 = rcnn_cls_loss(cls_pred, cls_targets, cls_targets >= 0) * \
                        cls_targets.size / cls_targets.shape[0] / num_rcnn_pos
                    rcnn_loss2 = rcnn_box_loss(box_pred, box_targets, box_masks) * \
                        box_pred.size / box_pred.shape[0] / num_rcnn_pos
                    rcnn_loss = rcnn_loss1 + rcnn_loss2
                    # generate targets for mask
                    mask_targets, mask_masks = net.mask_target(roi, gt_mask, matches, cls_targets)
                    # loss of mask
                    mask_loss = rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
                        mask_targets.size / mask_targets.shape[0] / mask_masks.sum()
                    # overall losses
                    losses.append(rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum())
                    if not args.horovod or hvd.rank() == 0:
                        metric_losses[0].append(rpn_loss1.sum())
                        metric_losses[1].append(rpn_loss2.sum())
                        metric_losses[2].append(rcnn_loss1.sum())
                        metric_losses[3].append(rcnn_loss2.sum())
                        metric_losses[4].append(mask_loss.sum())
                        add_losses[0].append([[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]])
                        add_losses[1].append([[rpn_box_targets, rpn_box_masks], [rpn_box]])
                        add_losses[2].append([[cls_targets], [cls_pred]])
                        add_losses[3].append([[box_targets, box_masks], [box_pred]])
                        add_losses[4].append([[mask_targets, mask_masks], [mask_pred]])
                        add_losses[5].append([[mask_targets, mask_masks], [mask_pred]])
                if args.amp:
                    with amp.scale_loss(losses, trainer) as scaled_losses:
                        autograd.backward(scaled_losses)
                else:
                    autograd.backward(losses)
            if not args.horovod or hvd.rank() == 0:
                for metric, record in zip(metrics, metric_losses):
                    metric.update(0, record)
                for metric, records in zip(metrics2, add_losses):
                    for pred in records:
                        metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get())
                                for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg))
                btic = time.time()

        # validate and save params
        if not args.horovod or hvd.rank() == 0:
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
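
# All three variants end each epoch with `save_params`, which is not defined in this
# file. A minimal sketch of the assumed behaviour: track the best mAP in the one-element
# `best_map` list and write checkpoints under `prefix`; the exact file-name pattern is an
# assumption for illustration.
def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix):
    """Save the best-so-far checkpoint and periodic checkpoints every `save_interval` epochs."""
    current_map = float(current_map)
    if current_map > best_map[0]:
        logger.info('[Epoch {}] mAP {} higher than current best {}, saving to {}'.format(
            epoch, current_map, best_map[0], '{:s}_best.params'.format(prefix)))
        best_map[0] = current_map
        net.save_parameters('{:s}_best.params'.format(prefix))
        with open(prefix + '_best_map.log', 'a') as f:
            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
    if save_interval and (epoch + 1) % save_interval == 0:
        logger.info('[Epoch {}] Saving parameters to {}'.format(
            epoch, '{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map)))
        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))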