def train_one_epoch(
    model,
    data_queue,
    opt,
    tot_steps,
    rank,
    epoch_id,
    world_size,
    enable_sublinear=False,
):
    """Run one training epoch using the static-graph (``jit.trace``) API.

    Args:
        model: detection model; exposes ``inputs`` tensor placeholders and a
            ``cfg`` with ``num_losses``, ``log_interval`` and ``losses_keys``.
        data_queue: iterator yielding mini-batch dicts with keys "data",
            "gt_boxes" and "im_info".
        opt: optimizer providing ``zero_grad``/``backward``/``step`` and
            ``param_groups``.
        tot_steps: number of iterations in this epoch.
        rank: distributed rank; only rank 0 accumulates and logs losses.
        epoch_id: current epoch index (used by the LR schedule and logging).
        world_size: number of distributed workers (used by the LR schedule).
        enable_sublinear: if True, enable sublinear-memory recomputation in
            the traced graph (trades compute for memory).
    """
    sublinear_cfg = jit.SublinearMemoryConfig() if enable_sublinear else None

    @jit.trace(symbolic=True, sublinear_memory_config=sublinear_cfg)
    def propagate():
        # Forward + backward inside the traced graph; optimizer step stays outside.
        loss_dict = model(model.inputs)
        opt.backward(loss_dict["total_loss"])
        losses = list(loss_dict.values())
        return losses

    meter = AverageMeter(record_len=model.cfg.num_losses)
    time_meter = AverageMeter(record_len=2)
    log_interval = model.cfg.log_interval

    # The log format string is loop-invariant: build it once instead of on
    # every iteration (the original rebuilt it each step under rank == 0).
    if rank == 0:
        loss_str = ", ".join(
            ["{}:%f".format(loss) for loss in model.cfg.losses_keys]
        )
        log_info_str = (
            "e%d, %d/%d, lr:%f, "
            + loss_str
            + ", train_time:%.3fs, data_time:%.3fs"
        )

    for step in range(tot_steps):
        adjust_learning_rate(opt, epoch_id, step, model, world_size)

        # Time the data-loading and the compute phases separately.
        data_tik = time.time()
        mini_batch = next(data_queue)
        data_tok = time.time()

        model.inputs["image"].set_value(mini_batch["data"])
        model.inputs["gt_boxes"].set_value(mini_batch["gt_boxes"])
        model.inputs["im_info"].set_value(mini_batch["im_info"])

        tik = time.time()
        opt.zero_grad()
        loss_list = propagate()
        opt.step()
        tok = time.time()

        time_meter.update([tok - tik, data_tok - data_tik])

        if rank == 0:
            meter.update([loss.numpy() for loss in loss_list])
            if step % log_interval == 0:
                average_loss = meter.average()
                # Lazy %-style args: logger formats only if the record is emitted.
                logger.info(
                    log_info_str,
                    epoch_id,
                    step,
                    tot_steps,
                    opt.param_groups[0]["lr"],
                    *average_loss,
                    *time_meter.average(),
                )
                meter.reset()
                time_meter.reset()
def train_one_epoch(model, data_queue, opt, gm, epoch, args):
    """Run one training epoch using the imperative (``GradManager``) API.

    Args:
        model: detection model; called with keyword tensors and returning a
            loss dict containing "total_loss"; exposes ``cfg`` with
            ``num_losses``, ``log_interval``, ``losses_keys`` and
            ``nr_images_epoch``.
        data_queue: iterator yielding mini-batch dicts with keys "data",
            "im_info" and "gt_boxes".
        opt: optimizer; ``step()`` returns the optimizer so ``clear_grad``
            can be chained.
        gm: ``GradManager`` used as a context for gradient recording.
        epoch: current epoch index (used by the LR schedule and logging).
        args: CLI namespace; ``args.batch_size`` is the per-worker batch size.
    """

    def train_func(image, im_info, gt_boxes):
        # Record gradients only around forward + backward; the parameter
        # update happens outside the GradManager context.
        with gm:
            loss_dict = model(image=image, im_info=im_info, gt_boxes=gt_boxes)
            gm.backward(loss_dict["total_loss"])
            loss_list = list(loss_dict.values())
        opt.step().clear_grad()
        return loss_list

    meter = AverageMeter(record_len=model.cfg.num_losses)
    time_meter = AverageMeter(record_len=2)
    log_interval = model.cfg.log_interval
    # Steps per epoch: images per epoch divided by the global batch size.
    tot_step = model.cfg.nr_images_epoch // (
        args.batch_size * dist.get_world_size()
    )

    # The log format string is loop-invariant: build it once instead of on
    # every iteration (the original rebuilt it each step on rank 0).
    if dist.get_rank() == 0:
        loss_str = ", ".join(
            ["{}:%f".format(loss) for loss in model.cfg.losses_keys]
        )
        log_info_str = (
            "e%d, %d/%d, lr:%f, "
            + loss_str
            + ", train_time:%.3fs, data_time:%.3fs"
        )

    for step in range(tot_step):
        adjust_learning_rate(opt, epoch, step, model.cfg, args)

        # Time the data-loading and the compute phases separately.
        data_tik = time.time()
        mini_batch = next(data_queue)
        data_tok = time.time()

        tik = time.time()
        loss_list = train_func(
            image=mge.tensor(mini_batch["data"]),
            im_info=mge.tensor(mini_batch["im_info"]),
            gt_boxes=mge.tensor(mini_batch["gt_boxes"]),
        )
        tok = time.time()

        time_meter.update([tok - tik, data_tok - data_tik])

        if dist.get_rank() == 0:
            meter.update([loss.numpy() for loss in loss_list])
            if step % log_interval == 0:
                # Lazy %-style args: logger formats only if the record is emitted.
                logger.info(
                    log_info_str,
                    epoch,
                    step,
                    tot_step,
                    opt.param_groups[0]["lr"],
                    *meter.average(),
                    *time_meter.average(),
                )
                meter.reset()
                time_meter.reset()