def train_one_epoch():
    """Run one training epoch over ``TRAIN_DATALOADER`` and log statistics.

    Depends on module-level objects defined elsewhere in this file: ``net``,
    ``optimizer``, ``EPOCH_CNT``, ``TRAIN_DATALOADER``, ``cfg`` and the
    helpers ``adjust_learning_rate``, ``compute_loss``, ``compute_acc``,
    ``IoUCalculator`` and ``log_string``.

    Every ``batch_interval`` batches, the running sums of every
    ``end_points`` entry whose key contains ``'loss'``, ``'acc'`` or
    ``'iou'`` are logged as means over that window and then reset. Batches
    after the last full window are accumulated but not logged. After the
    epoch, ``iou_calc.compute_iou()`` results are logged (scaled by 100).
    """
    batch_interval = 10  # log windowed means every `batch_interval` batches
    stat_dict = {}  # key -> running sum of that scalar over the current window

    adjust_learning_rate(optimizer, EPOCH_CNT)
    net.train()  # set model to training mode
    iou_calc = IoUCalculator(cfg)

    for batch_idx, batch_data in enumerate(TRAIN_DATALOADER):
        # `batch_data` maps input names (whatever the dataset yields) to
        # values that support `.cuda()` (presumably torch tensors); some
        # values are lists of such tensors. Move everything to the GPU,
        # replacing list items in place so the list object is preserved.
        for key in batch_data:
            value = batch_data[key]
            if isinstance(value, list):
                value[:] = [item.cuda() for item in value]
            else:
                batch_data[key] = value.cuda()

        # Forward / backward pass
        optimizer.zero_grad()
        end_points = net(batch_data)
        loss, end_points = compute_loss(end_points, cfg)
        loss.backward()
        optimizer.step()

        # Only the updated end_points is needed; the accuracy return value is unused.
        _, end_points = compute_acc(end_points)
        iou_calc.add_data(end_points)

        # Accumulate scalar statistics for the current logging window
        for key in end_points:
            if 'loss' in key or 'acc' in key or 'iou' in key:
                stat_dict[key] = stat_dict.get(key, 0) + end_points[key].item()

        if (batch_idx + 1) % batch_interval == 0:
            log_string(' ---- batch: %03d ----' % (batch_idx + 1))
            # Dividing by batch_interval assumes each key appeared in every
            # batch of the window — TODO confirm for all end_points keys.
            for key in sorted(stat_dict):
                log_string('mean %s: %f' % (key, stat_dict[key] / batch_interval))
                stat_dict[key] = 0

    mean_iou, iou_list = iou_calc.compute_iou()
    log_string('mean IoU:{:.1f}'.format(mean_iou * 100))
    log_string('IoU:' + ''.join('{:5.2f} '.format(100 * iou) for iou in iou_list))


def evaluate_one_epoch():
    """Evaluate ``net`` over ``TEST_DATALOADER`` and log statistics.

    Depends on module-level objects defined elsewhere in this file: ``net``,
    ``TEST_DATALOADER``, ``cfg``, ``torch`` and the helpers ``compute_loss``,
    ``compute_acc``, ``IoUCalculator`` and ``log_string``.

    The forward pass and all per-batch metric computation run under
    ``torch.no_grad()``. For each key of ``end_points`` containing
    ``'loss'``, ``'acc'`` or ``'iou'``, the mean over all batches is logged
    after the loop. Then ``iou_calc.compute_iou()`` results are logged
    (scaled by 100).
    """
    batch_interval = 10  # log a progress marker every `batch_interval` batches
    stat_dict = {}  # key -> sum of that scalar over all evaluated batches
    num_batches = 0  # explicit count so the final mean doesn't rely on a leaked loop var

    net.eval()  # set model to eval mode (for bn and dp)
    iou_calc = IoUCalculator(cfg)

    for batch_idx, batch_data in enumerate(TEST_DATALOADER):
        # `batch_data` maps input names to values supporting `.cuda()`
        # (presumably torch tensors), some of which are lists of them.
        # Move everything to the GPU, replacing list items in place.
        for key in batch_data:
            value = batch_data[key]
            if isinstance(value, list):
                value[:] = [item.cuda() for item in value]
            else:
                batch_data[key] = value.cuda()

        # Forward pass plus metric computation, all without autograd tracking
        # so no graph is built for anything computed during evaluation.
        with torch.no_grad():
            end_points = net(batch_data)
            _, end_points = compute_loss(end_points, cfg)
            _, end_points = compute_acc(end_points)
            iou_calc.add_data(end_points)

        # Accumulate scalar statistics over the whole evaluation run
        for key in end_points:
            if 'loss' in key or 'acc' in key or 'iou' in key:
                stat_dict[key] = stat_dict.get(key, 0) + end_points[key].item()
        num_batches = batch_idx + 1

        if (batch_idx + 1) % batch_interval == 0:
            log_string(' ---- batch: %03d ----' % (batch_idx + 1))

    # stat_dict is non-empty only if at least one batch ran, so num_batches >= 1
    # here. Dividing by num_batches assumes each key appeared in every batch.
    for key in sorted(stat_dict):
        log_string('eval mean %s: %f' % (key, stat_dict[key] / float(num_batches)))

    mean_iou, iou_list = iou_calc.compute_iou()
    log_string('mean IoU:{:.1f}'.format(mean_iou * 100))
    log_string('IoU:' + ''.join('{:5.2f} '.format(100 * iou) for iou in iou_list))