def main(args):
    """Distributed-validation driver.

    Shards the validation file list evenly across the configured devices,
    spawns one `worker` process per device, and aggregates per-image
    (accuracy, IoU) tuples from a shared result queue before printing a
    summary.

    Args:
        args: namespace with at least `devices`, `list_val`, `num_val`.
    """
    # Parse device ids
    default_dev, *parallel_dev = parse_devices(args.devices)
    all_devs = parallel_dev + [default_dev]
    all_devs = [x.replace('gpu', '') for x in all_devs]
    all_devs = [int(x) for x in all_devs]
    nr_devs = len(all_devs)

    with open(args.list_val, 'r') as f:
        lines = f.readlines()
        nr_files = len(lines)
    if args.num_val > 0:
        nr_files = min(nr_files, args.num_val)

    # ceil so the last shard may be short but every file is covered
    nr_files_per_dev = math.ceil(nr_files / nr_devs)

    pbar = tqdm(total=nr_files)

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    result_queue = Queue(500)
    procs = []
    for dev_id in range(nr_devs):
        start_idx = dev_id * nr_files_per_dev
        end_idx = min(start_idx + nr_files_per_dev, nr_files)
        proc = Process(target=worker,
                       args=(args, dev_id, start_idx, end_idx, result_queue))
        print('process:%d, start_idx:%d, end_idx:%d' % (dev_id, start_idx, end_idx))
        proc.start()
        procs.append(proc)

    # master fetches results
    # FIX: the original spun on `result_queue.empty()` + `continue`, a busy
    # loop that burns a full CPU core while waiting; a blocking get() sleeps
    # until the next result arrives and preserves the exact same aggregation.
    for _ in range(nr_files):
        acc, pix, intersection, union = result_queue.get()
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        pbar.update(1)

    for p in procs:
        p.join()

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
          .format(iou.mean(), acc_meter.average()*100))
    print('Evaluation Done!')
def train(segmentation_module, iterator, optimizers, history, epoch, args):
    """Run one training epoch over `args.epoch_iters` iterations.

    Backpropagates the mean of the 'total' loss returned by the module,
    steps every optimizer, tracks per-task loss/metric averages for the
    tasks in `names`, and appends (fractional_epoch, loss) to
    `history['train']` at each display interval.

    Args:
        segmentation_module: model returning {'loss': {...}, 'metric': {...}}.
        iterator: yields (batch_data, src_idx) pairs.
        optimizers: iterable of optimizers, all stepped each iteration.
        history: dict with history['train']['epoch'/'loss'] lists (mutated).
        epoch: 1-based epoch index.
        args: config namespace (fix_bn, epoch_iters, disp_iter, lr fields).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # per-task running averages; 'total' is tracked separately below
    names = ['object', 'part', 'scene', 'material']
    ave_losses = {n: AverageMeter() for n in names}
    ave_metric = {n: AverageMeter() for n in names}
    ave_losses['total'] = AverageMeter()

    # train(False) keeps BatchNorm layers frozen when fix_bn is set
    segmentation_module.train(not args.fix_bn)

    # main loop
    tic = time.time()
    for i in range(args.epoch_iters):
        batch_data, src_idx = next(iterator)
        data_time.update(time.time() - tic)

        segmentation_module.zero_grad()
        # forward pass
        ret = segmentation_module(batch_data)
        # Backward
        loss = ret['loss']['total'].mean()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # measure losses
        for name in ret['loss'].keys():
            ave_losses[name].update(ret['loss'][name].mean().item())

        # measure metrics
        # NOTE: scene metric will be much lower than benchmark
        for name in ret['metric'].keys():
            ave_metric[name].update(ret['metric'][name].mean().item())

        # calculate accuracy, and display
        if i % args.disp_iter == 0:
            # averages may be None before the first update; print 0 then
            loss_info = "Loss: total {:.4f}, ".format(ave_losses['total'].average())
            loss_info += ", ".join(["{} {:.2f}".format(
                n[0],
                ave_losses[n].average() if ave_losses[n].average() is not None else 0)
                for n in names])
            acc_info = "Accuracy: " + ", ".join(["{} {:4.2f}".format(
                n[0],
                ave_metric[n].average() if ave_metric[n].average() is not None else 0)
                for n in names])
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'LR: encoder {:.6f}, decoder {:.6f}, {}, {}'
                  .format(epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.running_lr_encoder, args.running_lr_decoder,
                          acc_info, loss_info))

            fractional_epoch = epoch - 1 + 1. * i / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(loss.item())

        # adjust learning rate
        cur_iter = i + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)
def evaluate(segmentation_module, loader, cfg, gpu, results_file=None):
    """Multi-scale evaluation of a segmentation model.

    For every sample, sums class scores predicted at each resized input in
    `batch_data['img_data']` (averaged over `cfg.DATASET.imgSizes`), takes
    the argmax as the prediction, and accumulates pixel accuracy and
    per-class intersection/union. Optionally writes a per-image CSV.

    Args:
        segmentation_module: model called as module(feed_dict, segSize=...).
        loader: yields single-sample batches (list of one dict).
        cfg: config node (DATASET.num_class, DATASET.imgSizes, VAL.visualize, DIR).
        gpu: device id for async_copy_to.
        results_file: optional CSV path for per-image records.
    """
    results = []
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        # synchronize so the timer measures actual GPU work
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                # drop non-tensor entries before copying to the device
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                # average scores across the evaluated scales
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']),
                pred,
                os.path.join(cfg.DIR, 'result'))

        if results_file:
            # per-image per-class IoU; epsilon guards empty classes
            ious = intersection / (union + 1e-10)
            recs = [batch_data["info"], acc] + np.column_stack(
                (union, ious)).ravel().tolist()
            results.append(recs)

        pbar.update(1)

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print(
        'Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.format(
            iou.mean(), acc_meter.average() * 100, time_meter.average()))

    if results_file:
        import pandas as pd
        # NOTE(review): `names` is not defined anywhere in this function —
        # this branch raises NameError unless `names` is a module-level
        # global (presumably the per-class label list); verify at call site.
        headers = ['File', 'Acc']
        for i in range(len(names)):
            headers.extend((names[i] + '_union', names[i] + '_iou'))
        pd.DataFrame(results, columns=headers).to_csv(results_file, index=False)
def calc_metrics(batch_data, outputs, args):
    """Compute source-separation metrics (SDR/SIR/SAR) for one batch.

    Unwarps predicted masks from log-frequency scale if needed, reconstructs
    each predicted source with ISTFT using the mixture phase, and evaluates
    BSS metrics against the ground-truth waveforms. Samples whose ground
    truth or prediction is (near-)silent are skipped as invalid.

    Args:
        batch_data: dict with 'mag_mix' (B,1,F,T), 'phase_mix', 'audios'.
        outputs: dict with 'pred_masks' — list of N per-source mask tensors.
        args: config namespace (num_mix, log_freq, stft_frame, stft_hop,
            binary_mask, mask_thres, device).

    Returns:
        [mean SDR of the raw mixture, mean SDR, mean SIR, mean SAR]
        (meter averages; may be None if no sample in the batch was valid —
        matches original behavior).
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']

    pred_masks_ = outputs['pred_masks']

    # unwarp log scale
    N = args.num_mix
    B = mag_mix.size(0)
    pred_masks_linear = [None for n in range(N)]
    if args.log_freq:
        # FIX: the unwarp grid depends only on B and pred_masks_[0], so it is
        # loop-invariant; the original rebuilt it on every iteration.
        grid_unwarp = torch.from_numpy(
            warpgrid(B, args.stft_frame // 2 + 1,
                     pred_masks_[0].size(3), warp=False)).to(args.device)
    for n in range(N):
        if args.log_freq:
            pred_masks_linear[n] = F.grid_sample(pred_masks_[n], grid_unwarp)
        else:
            pred_masks_linear[n] = pred_masks_[n]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    for n in range(N):
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_linear[n] = (
                pred_masks_linear[n] > args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        # reconstruct the mixture waveform from magnitude + phase
        mix_wav = istft_reconstruction(
            mag_mix[j, 0], phase_mix[j, 0], hop_length=args.stft_hop)

        # reconstruct each predicted component (mask applied to mixture mag)
        preds_wav = [None for n in range(N)]
        for n in range(N):
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(
                pred_mag, phase_mix[j, 0], hop_length=args.stft_hop)

        # separation performance computes
        L = preds_wav[0].shape[0]
        gts_wav = [None for n in range(N)]
        # a sample is only valid when every GT and prediction is non-silent
        valid = True
        for n in range(N):
            gts_wav[n] = audios[n][j, 0:L].numpy()
            valid *= np.sum(np.abs(gts_wav[n])) > 1e-5
            valid *= np.sum(np.abs(preds_wav[n])) > 1e-5
        if valid:
            sdr, sir, sar, _ = bss_eval_sources(
                np.asarray(gts_wav), np.asarray(preds_wav), False)
            # baseline: metrics of the unseparated mixture vs. each GT source
            sdr_mix, _, _, _ = bss_eval_sources(
                np.asarray(gts_wav),
                np.asarray([mix_wav[0:L] for n in range(N)]), False)
            sdr_mix_meter.update(sdr_mix.mean())
            sdr_meter.update(sdr.mean())
            sir_meter.update(sir.mean())
            sar_meter.update(sar.mean())

    return [
        sdr_mix_meter.average(), sdr_meter.average(), sir_meter.average(),
        sar_meter.average()
    ]
def evaluate(netWrapper, loader, history, epoch, args):
    """Evaluate an audio-separation network and render an HTML report.

    Runs the wrapper over `loader`, accumulating loss and BSS metrics
    (via `calc_metrics`), records them in `history['val']`, and writes
    visualization rows to an HTML page under `args.vis`.

    Args:
        netWrapper: model; forward(batch_data, args) -> (err, outputs).
        loader: validation data loader.
        history: dict with history['val'][...] lists (mutated).
        epoch: current epoch number (plots are skipped for epoch 0).
        args: config namespace (vis, num_mix, num_vis, ckpt).
    """
    print('Evaluating at {} epochs...'.format(epoch))
    # NOTE(review): grad mode is disabled globally here and never restored
    # inside this function — presumably the training loop re-enables it;
    # confirm at the call site.
    torch.set_grad_enabled(False)

    # remove previous viz results
    makedirs(args.vis, remove=True)

    # switch to eval mode
    netWrapper.eval()

    # initialize meters
    loss_meter = AverageMeter()
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # initialize HTML header
    visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
    header = ['Filename', 'Input Mixed Audio']
    for n in range(1, args.num_mix + 1):
        header += [
            'Video {:d}'.format(n), 'Predicted Audio {:d}'.format(n),
            'GroundTruth Audio {}'.format(n), 'Predicted Mask {}'.format(n),
            'GroundTruth Mask {}'.format(n)
        ]
    header += ['Loss weighting']
    visualizer.add_header(header)
    vis_rows = []

    for i, batch_data in enumerate(loader):
        # forward pass
        err, outputs = netWrapper.forward(batch_data, args)
        err = err.mean()

        loss_meter.update(err.item())
        print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))

        # calculate metrics
        sdr_mix, sdr, sir, sar = calc_metrics(batch_data, outputs, args)
        sdr_mix_meter.update(sdr_mix)
        sdr_meter.update(sdr)
        sir_meter.update(sir)
        sar_meter.update(sar)

        # output visualization (only until num_vis rows are collected)
        if len(vis_rows) < args.num_vis:
            output_visuals(vis_rows, batch_data, outputs, args)

    print('[Eval Summary] Epoch: {}, Loss: {:.4f}, '
          'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'.format(
              epoch, loss_meter.average(), sdr_mix_meter.average(),
              sdr_meter.average(), sir_meter.average(), sar_meter.average()))
    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['sdr'].append(sdr_meter.average())
    history['val']['sir'].append(sir_meter.average())
    history['val']['sar'].append(sar_meter.average())

    print('Plotting html for visualization...')
    visualizer.add_rows(vis_rows)
    visualizer.write_html()

    # Plot figure
    if epoch > 0:
        print('Plotting figures...')
        plot_loss_metrics(args.ckpt, history)
def evaluate(segmentation_module, loader, cfg, gpu):
    """Multi-scale segmentation evaluation (OOD-config variant).

    Same accumulation scheme as the standard evaluator: per-scale scores
    are summed (averaged over `cfg.DATASET.imgSizes`), the argmax gives the
    prediction, and accuracy/IoU are aggregated over the loader.

    Args:
        segmentation_module: model called as module(feed_dict, segSize=...).
        loader: yields single-sample batches (list of one dict).
        cfg: config node (DATASET.*, OOD.exclude_back, VAL.visualize, TEST.result).
        gpu: device id for async_copy_to.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        # synchronize so the timer measures actual GPU work
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                # drop non-tensor entries before copying to the device
                del feed_dict['img_ori']
                del feed_dict['info']
                del feed_dict['name']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            # NOTE(review): `tmp_scores` (including the exclude_back slice
            # that drops class 0) is computed but never used — `pred` below
            # is taken from `scores`. Likely the argmax was meant to use
            # `tmp_scores`; confirm intent before changing behavior.
            tmp_scores = scores
            if cfg.OOD.exclude_back:
                tmp_scores = tmp_scores[:,1:]

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']),
                pred,
                os.path.join(cfg.TEST.result),
                as_numpy(scores.squeeze(0).cpu())
            )

        pbar.update(1)

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou.mean(), acc_meter.average()*100, time_meter.average()))
def evaluate(model, loader, gpu_mode, num_class=7):
    """Evaluate `model` over `loader` and return accuracy / IoU statistics.

    Args:
        model: network; `model(img)` must return per-class scores with the
            class dimension at axis 1.
        loader: iterable yielding (img, mask, _) triples.
        gpu_mode: when truthy, move img/mask to the GPU before the forward.
        num_class: number of classes for the IoU computation (default 7).

    Returns:
        dict with keys:
            'acc'      — mean pixel accuracy (AverageMeter average),
            'iou'      — per-class IoU values,
            'iou_mean' — mean IoU over all classes.
    """
    # FIX: removed the large block of commented-out confusion-matrix code,
    # the placeholder dummy values in `res`, the redundant `del` statements,
    # and the unused enumerate index. Behavior is unchanged.
    acc_meter = AverageMeter()
    inter_meter = AverageMeter()
    union_meter = AverageMeter()

    for img, mask, _ in loader:
        if gpu_mode:
            img = img.cuda()
            mask = mask.cuda()

        # argmax over the class dimension -> predicted label map
        output = model(img).max(1)[1]

        # accumulate pixel accuracy and per-class intersection/union
        acc_meter.update(accuracy(output, mask))
        intersection, union = intersectionAndUnion(output, mask, num_class)
        inter_meter.update(intersection)
        union_meter.update(union)

    # epsilon avoids division by zero for classes absent from the data
    iou = inter_meter.sum / (union_meter.sum + 1e-10)
    return {
        'acc': acc_meter.average(),
        'iou': iou,
        'iou_mean': iou.mean(),
    }
def train(segmentation_module, iterator, optimizers, history, epoch, args):
    """One training epoch for the multi-task segmentation module.

    Backpropagates the mean 'total' loss, steps all optimizers, tracks
    per-task loss/metric running averages, prints progress every
    `args.disp_iter` iterations, and logs (fractional_epoch, loss) into
    `history['train']`.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    task_names = ['object', 'part', 'scene', 'material']
    loss_meters = {key: AverageMeter() for key in task_names}
    metric_meters = {key: AverageMeter() for key in task_names}
    loss_meters['total'] = AverageMeter()

    # keep BatchNorm frozen when fix_bn is requested
    segmentation_module.train(not args.fix_bn)

    t0 = time.time()
    for step in range(args.epoch_iters):
        batch_data, src_idx = next(iterator)
        timer_data.update(time.time() - t0)

        segmentation_module.zero_grad()

        # forward + backward
        out = segmentation_module(batch_data)
        total_loss = out['loss']['total'].mean()
        total_loss.backward()
        for opt in optimizers:
            opt.step()

        timer_batch.update(time.time() - t0)
        t0 = time.time()

        # accumulate running loss averages per task
        for key in out['loss'].keys():
            loss_meters[key].update(out['loss'][key].mean().item())

        # accumulate running metric averages per task
        # NOTE: scene metric will be much lower than benchmark
        for key in out['metric'].keys():
            metric_meters[key].update(out['metric'][key].mean().item())

        # periodic progress display + history logging
        if step % args.disp_iter == 0:
            loss_parts = []
            for key in task_names:
                avg = loss_meters[key].average()
                loss_parts.append(
                    "{} {:.2f}".format(key[0], 0 if avg is None else avg))
            loss_info = "Loss: total {:.4f}, ".format(
                loss_meters['total'].average()) + ", ".join(loss_parts)

            acc_parts = []
            for key in task_names:
                avg = metric_meters[key].average()
                acc_parts.append(
                    "{} {:4.2f}".format(key[0], 0 if avg is None else avg))
            acc_info = "Accuracy: " + ", ".join(acc_parts)

            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'LR: encoder {:.6f}, decoder {:.6f}, {}, {}'.format(
                      epoch, step, args.epoch_iters,
                      timer_batch.average(), timer_data.average(),
                      args.running_lr_encoder, args.running_lr_decoder,
                      acc_info, loss_info))

            fractional_epoch = epoch - 1 + 1. * step / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(total_loss.item())

        # learning-rate schedule, driven by the global iteration counter
        cur_iter = step + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)
def mixup_train(loader, model, criterion, optimizer, epoch, use_cuda):
    """Train `model` for one epoch with triple-mixup data augmentation.

    Applies `mixup_data_triple` to each batch, computes the mixed loss via
    `mixup_criterion_triple`, and steps the optimizer. Top-1/top-5 meters
    are updated with placeholder zeros (accuracy is not meaningful on
    mixed targets).

    Args:
        loader: training data loader.
        model: network to train.
        criterion: base loss passed into the mixup criterion.
        optimizer: optimizer stepped per batch.
        epoch: 0-based epoch index (displayed 1-based).
        use_cuda: move inputs/targets to GPU when truthy.

    Returns:
        (average loss, average top-1) tuple for the epoch.
    """
    global BEST_ACC, LR_STATE

    # switch to train mode (keep BN statistics frozen when fix_bn is set)
    if not cfg.CLS.fix_bn:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(loader):
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, batch=batch_idx,
                             batch_per_epoch=len(loader))
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # mixup: blend three targets with weight lam
        inputs, targets_a, targets_b, targets_c, lam = mixup_data_triple(
            inputs, targets, ALPHA, use_cuda)
        optimizer.zero_grad()
        inputs, targets_a, targets_b, targets_c = Variable(inputs), Variable(
            targets_a), Variable(targets_b), Variable(targets_c)

        # measure data loading time
        data_time.update(time.time() - end)

        # forward pass: compute output
        outputs = model(inputs)

        # mixed loss over the three target sets
        loss_func = mixup_criterion_triple(targets_a, targets_b, targets_c, lam)
        loss = loss_func(criterion, outputs)

        # backward
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # record loss; accuracy meters get placeholder zeros under mixup
        prec1, prec5 = [0.0], [0.0]
        # FIX: `loss.data[0]` raises on 0-dim tensors (PyTorch >= 0.5);
        # `.item()` is the portable scalar accessor on old and new versions.
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if (batch_idx + 1) % cfg.CLS.disp_iter == 0:
            print(
                'Training: [{}/{}][{}/{}] | Best_Acc: {:4.2f}% | Time: {:.2f} | Data: {:.2f} | '
                'LR: {:.8f} | Top1: {:.4f}% | Top5: {:.4f}% | Loss: {:.4f} | Total: {:.2f}'
                .format(epoch + 1, cfg.CLS.epochs, batch_idx + 1, len(loader),
                        BEST_ACC, batch_time.average(), data_time.average(),
                        LR_STATE, top1.avg, top5.avg, losses.avg,
                        batch_time.sum + data_time.sum))

    return (losses.avg, top1.avg)
def evaluate(nets, loader, loader_2, history, epoch, args, isVis=True):
    """Evaluate `nets` on two validation sets and log both to `history`.

    The first loader's results are reported as "Cityscapes" and stored in
    history['val']; the second loader's as "BDD" in history['val_2'].
    Each pass accumulates loss, pixel accuracy, and per-class IoU.

    Args:
        nets: iterable of networks (all switched to eval mode).
        loader, loader_2: the two validation data loaders.
        history: dict with 'val'/'val_2' sub-dicts of lists (mutated).
        epoch: current epoch, used for logging only.
        args: config namespace (num_class, visualization options).
        isVis: when True, write prediction/reconstruction visualizations.
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    loss_meter_2 = AverageMeter()
    acc_meter_2 = AverageMeter()
    intersection_meter_2 = AverageMeter()
    union_meter_2 = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    # ---- first dataset (Cityscapes) ----
    for i, batch_data in enumerate(loader):
        # forward pass
        torch.cuda.empty_cache()
        pred, recon, err = forward_with_loss(nets, batch_data, is_train=False)
        loss_meter.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)
        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if isVis:
            visualize(batch_data, pred, args)
            visualize_recon(batch_data, recon, args)

    # ---- second dataset (BDD) ----
    for i, batch_data in enumerate(loader_2):
        # forward pass
        torch.cuda.empty_cache()
        pred, recon, err = forward_with_loss(nets, batch_data, is_train=False)
        loss_meter_2.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter_2.update(acc, pix)
        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter_2.update(intersection)
        union_meter_2.update(union)

        # visualization
        if isVis:
            visualize_recon(batch_data, recon, args)
            visualize(batch_data, pred, args)

    # summary: first dataset
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[Cityscapes Eval Summary]:')
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        epoch, loss_meter.average(), iou.mean(), acc_meter.average() * 100))
    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())
    history['val']['mIoU'].append(iou.mean())

    # summary: second dataset (reuses the name `iou` for the _2 meters)
    iou = intersection_meter_2.sum / (union_meter_2.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[BDD Eval Summary]:')
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        epoch, loss_meter_2.average(), iou.mean(), acc_meter_2.average() * 100))
    history['val_2']['epoch'].append(epoch)
    history['val_2']['err'].append(loss_meter_2.average())
    history['val_2']['acc'].append(acc_meter_2.average())
    history['val_2']['mIoU'].append(iou.mean())
def train(segmentation_module,
          iterator,
          optimizers,
          epoch,
          cfg,
          history=None,
          foveation_module=None):
    """One training epoch, with an optional foveated-patch sampling mode.

    Non-foveated path: a plain forward/backward/step per iteration.
    Foveated path (cfg.MODEL.foveation): low-resolution grid coordinates
    are shuffled and consumed in mini-batches; the foveation module crops
    patches per coordinate, the segmentation module is trained on them, and
    gradients are synchronized according to cfg.TRAIN.sync_location
    ('rand' / 'mean_mbs' / 'none_sync'). Optional extras: entropy
    regularisation of the foveation distribution, patch-size-weighted
    lr/weight-decay scaling, and Gumbel temperature annealing.

    NOTE(review): this body was reconstructed from whitespace-mangled
    source; indentation of a few branches (flagged below) could not be
    fully verified against the original layout.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if cfg.MODEL.foveation:
        foveation_module.train(not cfg.TRAIN.fix_bn)

    # main loop
    tic = time.time()
    for i in range(cfg.TRAIN.epoch_iters):
        # load a batch of data; a non-list batch means single-GPU mode and
        # is normalized into a one-element list so later code is uniform
        batch_data = next(iterator)
        if type(batch_data) is not list:
            single_gpu_mode = True
            batch_data['img_data'] = batch_data['img_data'][0].cuda()
            batch_data['seg_label'] = batch_data['seg_label'][0].cuda()
            batch_data = [batch_data]
        else:
            single_gpu_mode = False
        data_time.update(time.time() - tic)

        segmentation_module.zero_grad()
        if cfg.MODEL.foveation:
            foveation_module.zero_grad()

        # adjust learning rate non_foveation
        if not cfg.MODEL.foveation:
            cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
            adjust_learning_rate(optimizers, cur_iter, cfg)

        # Foveation
        if cfg.MODEL.foveation:
            # Note by sudo_ means here is only for size estimation purpose
            # because batch_data is obtained by user modified DataParallel,
            # s.t. batch_data is a list with length as len(gpus)
            # and each batch_data[i] is the actualy dict(batch_data)
            # returned in dataset.TrainDataset
            # for ib in range(len(batch_data)):
            #     print('img_data shape: ', batch_data[ib]['img_data'].shape)
            sudo_X, sudo_Y = batch_data[0]['img_data'], batch_data[0][
                'seg_label']
            fov_map_scale = cfg.MODEL.fov_map_scale
            # NOTE: although here we use batch imresize yet in practical
            # batch size for X = 1
            sudo_X_lr = b_imresize(
                sudo_X,
                (round(sudo_X.shape[2] / fov_map_scale),
                 round(sudo_X.shape[3] / (fov_map_scale * cfg.MODEL.patch_ap))),
                interp='bilinear')
            if cfg.TRAIN.auto_fov_location_step:
                # one step per low-res grid cell
                cfg.TRAIN.fov_location_step = round(
                    sudo_X.shape[2] / fov_map_scale) * round(
                        sudo_X.shape[3] / (fov_map_scale * cfg.MODEL.patch_ap))

            # foveation (crop as you go)
            fov_location_batch_step = 0
            if cfg.TRAIN.sync_location == 'rand':
                # bp at each step and sync at random
                rand_location = random.randint(1, cfg.TRAIN.fov_location_step - 1)
            elif cfg.TRAIN.sync_location == 'mean_mbs':
                # bp and opt at each step and sync at random (last of random
                # X_lr_cord list) with average loss
                rand_location = cfg.TRAIN.fov_location_step
            elif cfg.TRAIN.sync_location == 'none_sync':
                # bp and opt at each step
                rand_location = cfg.TRAIN.fov_location_step

            # mini_batch: shuffled list of all low-res grid coordinates
            X_lr_cord = []
            for xi in range(sudo_X_lr.shape[2]):
                for yi in range(sudo_X_lr.shape[3]):
                    X_lr_cord.append((xi, yi))
            random.shuffle(X_lr_cord)
            mbs = cfg.TRAIN.mini_batch_size
            mb_iter_count = 0
            mb_idx = 0
            mb_idx_count = 0
            while mb_idx < len(X_lr_cord) and mb_idx_count < rand_location:
                # correct zero_grad https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903
                # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
                # https://discuss.pytorch.org/t/whats-the-difference-between-optimizer-zero-grad-vs-nn-module-zero-grad/59233
                segmentation_module.zero_grad()
                foveation_module.zero_grad()

                batch_iters = rand_location
                cur_iter = fov_location_batch_step + (
                    i - 1) * batch_iters + (
                        epoch - 1) * cfg.TRAIN.epoch_iters * batch_iters
                # print('original max_iter:', cfg.TRAIN.max_iters)
                if cfg.TRAIN.fov_scale_lr != '' or cfg.TRAIN.fov_scale_weight_decay != '':
                    # weighted patch size normalized _ mini_batch average
                    if mb_idx == 0:
                        wpsn_mb = 1
                    else:
                        wpsn_mb = wpsn_mb / mbs
                    if cfg.TRAIN.sync_location != 'rand':
                        fov_max_iters = batch_iters * cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch
                    if cfg.TRAIN.fov_scale_lr == 'pen_sp':
                        # penalty small patch, smaller average patch size
                        # smaller learning rate
                        lr_scale = float(wpsn_mb)
                    elif cfg.TRAIN.fov_scale_lr == 'pen_lp':
                        # penalty large patch, larger average patch size
                        # smaller learning rate
                        lr_scale = float(1 - wpsn_mb)
                    else:
                        lr_scale = 1.
                    if cfg.TRAIN.fov_scale_weight_decay == 'reg_sp':
                        # regularise small patch, smaller average patch size
                        # larger regularisation
                        wd_scale = float(1 - wpsn_mb)
                    elif cfg.TRAIN.fov_scale_weight_decay == 'reg_lp':
                        # regularise large patch, larger average patch size
                        # larger regularisation
                        wd_scale = float(wpsn_mb)
                    else:
                        wd_scale = 1.
                    if cfg.TRAIN.fov_scale_lr != '' or cfg.TRAIN.fov_scale_weight_decay != '':
                        wpsn_mb = 0
                    # print('before fov_pow lr_scale={}, wd_scale={}'.format(lr_scale, wd_scale))
                    adjust_learning_rate(optimizers,
                                         cur_iter,
                                         cfg,
                                         lr_mbs=True,
                                         f_max_iter=fov_max_iters,
                                         lr_scale=lr_scale,
                                         wd_scale=wd_scale)
                # NOTE(review): placement of this branch relative to the
                # fov_scale block above is ambiguous in the mangled source;
                # `fov_max_iters` is only bound inside that block, so this
                # raises NameError if annealing is on while both fov_scale
                # options are '' — confirm against the original layout.
                if cfg.MODEL.gumbel_tau_anneal:
                    adjust_gms_tau(cur_iter, cfg, r=1. / fov_max_iters)

                if cfg.TRAIN.entropy_regularisation:
                    mbs_mean_entropy_reg = 0

                # draw the next mini-batch of grid coordinates
                xi = []
                yi = []
                mini_batch_sample = 0
                while mini_batch_sample < mbs and mb_idx < len(X_lr_cord):
                    xi.append(X_lr_cord[mb_idx][0])
                    yi.append(X_lr_cord[mb_idx][1])
                    mb_idx += 1
                    fov_location_batch_step += 1
                    mb_idx_count += 1
                    mini_batch_sample += 1
                xi = tuple(xi)
                yi = tuple(yi)
                for idx in range(len(batch_data)):
                    batch_data[idx]['cor_info'] = (xi, yi, rand_location,
                                                   fov_location_batch_step)
                # at the sync step the foveation module also returns grads
                if fov_location_batch_step == rand_location:
                    if single_gpu_mode:
                        patch_data, F_Xlr, print_grad = foveation_module(
                            batch_data[0])
                    else:
                        patch_data, F_Xlr, print_grad = foveation_module(
                            batch_data)
                else:
                    if single_gpu_mode:
                        patch_data, F_Xlr = foveation_module(batch_data[0])
                    else:
                        patch_data, F_Xlr = foveation_module(batch_data)

                # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
                # by set base = len(patch_bank), uniform distribution will
                # have entropy = 1 (so absolute uncertain)
                if cfg.TRAIN.entropy_regularisation:
                    # comprosed solution consider batch size != 1
                    F_Xlr_c = F_Xlr.clone()
                    if cfg.MODEL.gumbel_softmax:
                        F_Xlr_c = F_Xlr_c.exp()
                    mean_entropy_reg = 0
                    for i_batch in range(F_Xlr_c.shape[0]):
                        mean_entropy_reg += (
                            -F_Xlr_c[i_batch, :, xi, yi] *
                            F_Xlr_c[i_batch, :, xi, yi].log()).sum()
                    mbs_mean_entropy_reg += mean_entropy_reg / (
                        rand_location // mbs)
                if cfg.TRAIN.entropy_regularisation:
                    # comprosed solution consider batch size != 1
                    mean_entropy = 0
                    for i_batch in range(F_Xlr.shape[0]):
                        mean_entropy += (entropy(
                            F_Xlr[i_batch, :, xi, yi].cpu().detach().numpy(),
                            base=len(
                                cfg.MODEL.patch_bank)).mean()) / F_Xlr.shape[0]
                if cfg.TRAIN.fov_scale_lr != '':
                    print(F_Xlr.shape)
                    pb = cfg.MODEL.patch_bank
                    # weighted patch size: expectation of patch size under
                    # the foveation distribution at the sampled coordinates
                    wps = torch.sum(
                        F_Xlr[:, :, xi, yi] *
                        torch.tensor(pb).float().unsqueeze(0).unsqueeze(
                            -1).unsqueeze(-1).to(F_Xlr.device),
                        dim=1).mean()
                    wpsn = (wps - pb[0]) / (pb[-1] - pb[0])
                    print('wpsn: ', wpsn)
                    wpsn_mb += wpsn

                # split multi gpu collected dict into list to keep
                # DataParall work for segmentation_module
                # print('patch_data_img_data_shape: ', patch_data['img_data'].shape)
                if mb_iter_count == 0:
                    patch_data_list = []
                    for idx in range(len(batch_data)):
                        patch_data_temp = dict()
                        patch_data_temp['img_data'] = torch.split(
                            patch_data['img_data'],
                            patch_data['img_data'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        patch_data_temp['seg_label'] = torch.split(
                            patch_data['seg_label'],
                            patch_data['seg_label'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        if cfg.MODEL.hard_fov_pred:
                            patch_data_temp['hard_max_idx'] = torch.split(
                                patch_data['hard_max_idx'],
                                patch_data['hard_max_idx'].shape[0] //
                                len(batch_data),
                                dim=0)[idx]
                        patch_data_list.append(patch_data_temp)
                else:
                    # append this mini-batch's patches to the accumulated
                    # per-GPU lists (reuses patch_data_temp from above)
                    for idx in range(len(batch_data)):
                        patch_data_temp['img_data'] = torch.split(
                            patch_data['img_data'],
                            patch_data['img_data'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        patch_data_temp['seg_label'] = torch.split(
                            patch_data['seg_label'],
                            patch_data['seg_label'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        patch_data_list[idx]['img_data'] = torch.cat([
                            patch_data_list[idx]['img_data'],
                            patch_data_temp['img_data']
                        ])
                        patch_data_list[idx]['seg_label'] = torch.cat([
                            patch_data_list[idx]['seg_label'],
                            patch_data_temp['seg_label']
                        ])
                        if cfg.MODEL.hard_fov_pred:
                            patch_data_temp['hard_max_idx'] = torch.split(
                                patch_data['hard_max_idx'],
                                patch_data['hard_max_idx'].shape[0] //
                                len(batch_data),
                                dim=0)[idx]
                            patch_data_list[idx]['hard_max_idx'] = torch.cat([
                                patch_data_list[idx]['hard_max_idx'],
                                patch_data_temp['hard_max_idx']
                            ])
                mb_iter_count += 1
                # NOTE(review): mb_iter_count is reset to 0 immediately after
                # being incremented, so the `else` accumulation branch above
                # appears unreachable as written — verify intent.
                mb_iter_count = 0

                # forward pass
                # print('[patch_data_list_img_data_shape: ]', patch_data_list[0]['img_data'].shape)
                if single_gpu_mode:
                    loss, acc = segmentation_module(patch_data_list[0])
                else:
                    loss, acc = segmentation_module(patch_data_list)
                if cfg.MODEL.categorical:
                    # REINFORCE-style weighting by the action log-prob
                    # print('log_prob_act:', patch_data['log_prob_act'])
                    # print('ori loss:', loss)
                    if cfg.MODEL.inv_categorical:
                        loss = -patch_data['log_prob_act'] * loss
                    else:
                        loss = patch_data['log_prob_act'] * loss
                        # print('reinforced loss:', loss)
                if not single_gpu_mode:
                    loss = loss.mean()
                    acc = acc.mean()
                if cfg.TRAIN.entropy_regularisation:
                    loss += cfg.TRAIN.entropy_regularisation_weight * mbs_mean_entropy_reg

                # accumulate per-step loss/acc for the running averages
                if fov_location_batch_step // mbs == 1:
                    loss_step = loss.data
                    acc_step = acc.data
                else:
                    loss_step += loss.data
                    acc_step += acc.data

                # at the randomly chosen sync location, retain the graph for
                # the final synchronized backward; otherwise backprop now
                if fov_location_batch_step == rand_location:
                    loss_retain = loss
                elif fov_location_batch_step != cfg.TRAIN.fov_location_step:
                    loss.backward()
                    if cfg.TRAIN.sync_location != 'rand':
                        for optimizer in optimizers:
                            optimizer.step()
                if fov_location_batch_step == cfg.TRAIN.fov_location_step:
                    if cfg.TRAIN.sync_location != 'none_sync':
                        # print('iter {}: bp at random retained location {}/{}, xi={}, yi={}'.format(i, rand_location, cfg.TRAIN.fov_location_step, xi, yi))
                        if cfg.TRAIN.sync_location == 'mean_mbs':
                            loss_retain.data = loss_step / (
                                cfg.TRAIN.fov_location_step / mbs)
                        loss_retain.backward()
                    else:
                        loss.backward()
                    for optimizer in optimizers:
                        optimizer.step()
                    # normalize the accumulated step statistics
                    loss_step /= (cfg.TRAIN.fov_location_step / mbs)
                    acc_step /= (cfg.TRAIN.fov_location_step / mbs)
                    ave_total_loss.update(loss_step.data.item())
                    ave_acc.update(acc_step.data.item() * 100)
                    fov_location_batch_step = 0
                    if not cfg.TRAIN.auto_fov_location_step and cfg.TRAIN.sync_location == 'rand':
                        rand_location = random.randint(
                            2, cfg.TRAIN.fov_location_step - 1)
                # print('iter {}: {}/{}/{} foveate points, xi={}, yi={}\n'.format(i, fov_location_batch_step, mb_idx, sudo_X_lr.shape[2]*sudo_X_lr.shape[3], xi, yi))
        else:
            # forward pass
            loss, acc = segmentation_module(batch_data)
            print()
            loss_step = loss.mean()
            acc_step = acc.mean()
            # Backward
            loss_step.backward()
            for optimizer in optimizers:
                optimizer.step()
            # update average loss and acc
            ave_total_loss.update(loss_step.data.item())
            ave_acc.update(acc_step.data.item() * 100)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # calculate accuracy, and display
        if i % cfg.TRAIN.disp_iter == 0:
            if cfg.MODEL.foveation:
                print(
                    'iter {}: bp at random retained location {}/{}, xi={}, yi={}'
                    .format(i, rand_location, cfg.TRAIN.fov_location_step, xi,
                            yi))
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      epoch, i, cfg.TRAIN.epoch_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))
            fractional_epoch = epoch - 1 + 1. * i / cfg.TRAIN.epoch_iters
            if history is not None:
                history['train']['epoch'].append(fractional_epoch)
                history['train']['loss'].append(ave_total_loss.average())
                history['train']['acc'].append(ave_acc.average() / 100)
                history['train']['print_grad'] = print_grad
def evaluate(segmentation_module, loader, cfg, gpu):
    """Run multi-scale evaluation over the validation loader.

    For every sample, the prediction scores from each resized input image are
    averaged, argmax-ed into a label map, and compared against the ground-truth
    segmentation to accumulate pixel accuracy and per-class IoU.

    Args:
        segmentation_module: model returning per-pixel class scores when called
            with ``segSize=...``.
        loader: validation DataLoader yielding one-sample batches.
        cfg: experiment config (reads ``DATASET.num_class``,
            ``DATASET.imgSizes``, ``VAL.visualize``, ``DIR``).
        gpu: device id used by ``async_copy_to``.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # loader yields a single-element batch; unwrap it
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        torch.cuda.synchronize()
        tic = time.perf_counter()
        # Inference only — no autograd graph needed.
        with torch.no_grad():
            seg_size = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, cfg.DATASET.num_class, seg_size[0], seg_size[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # BUG FIX: the original commented out this forward pass and
                # instead ran a debug loss/backward, leaving `scores_tmp`
                # undefined and crashing with NameError on the next line.
                scores_tmp = segmentation_module(feed_dict, segSize=seg_size)
                # average scores across the multi-scale inputs
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']),
                pred,
                os.path.join(cfg.DIR, 'result'))

        pbar.update(1)

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print(
        'Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.format(
            iou.mean(), acc_meter.average() * 100, time_meter.average()))
def train(segmentation_module, loader_train, optimizers, history, epoch, args):
    """Train the segmentation module for one epoch.

    Tracks running loss, pixel accuracy and three per-class Jaccard indices,
    prints progress every 10 effective batches, appends epoch-level stats to
    ``history['train']`` and steps the LR schedule at the end of the epoch.

    Args:
        segmentation_module: model returning ``(loss, (acc, jaccard))`` when
            called with ``(batch_data, epoch)``.
        loader_train: training DataLoader.
        optimizers: sequence of optimizers to step each iteration.
        history: nested dict of training curves, mutated in place.
        epoch: 1-based current epoch number.
        args: experiment options (lr schedule, batch size, display flags).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    ave_j1 = AverageMeter()
    ave_j2 = AverageMeter()
    ave_j3 = AverageMeter()

    segmentation_module.train(not args.fix_bn)

    # main loop
    tic = time.time()
    iter_count = 0

    # When resuming from a checkpoint, restore the polynomially decayed LR
    # so the encoder does not restart at the base learning rate.
    if epoch == args.start_epoch and args.start_epoch > 1:
        scale_running_lr = ((1. - float(epoch - 1) / args.num_epoch) ** args.lr_pow)
        args.running_lr_encoder = args.lr_encoder * scale_running_lr
        for param_group in optimizers[0].param_groups:
            param_group['lr'] = args.running_lr_encoder

    # BUG FIX: the display branch below prints `i`, which the original never
    # defined (plain `for batch_data in loader_train:`) -> NameError.
    for i, batch_data in enumerate(loader_train):
        data_time.update(time.time() - tic)
        batch_data["image"] = batch_data["image"].cuda()
        segmentation_module.zero_grad()

        # forward pass
        loss, acc = segmentation_module(batch_data, epoch)
        loss = loss.mean()
        # BUG FIX: the original `for j in jaccard: j = j.float().mean()` only
        # rebound the loop variable, so the per-class tensors were never
        # reduced before the `.item()` calls below.
        jaccard = [j.float().mean() for j in acc[1]]
        acc = acc[0].float().mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        iter_count += args.batch_size_per_gpu

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)
        ave_j1.update(jaccard[0].data.item() * 100)
        ave_j2.update(jaccard[1].data.item() * 100)
        ave_j3.update(jaccard[2].data.item() * 100)

        # display every 10 effective batches
        if iter_count % (args.batch_size_per_gpu * 10) == 0:
            if not args.unet:
                print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                      'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                      'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                          epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.running_lr_encoder, args.running_lr_decoder,
                          ave_acc.average(), ave_total_loss.average()))
            else:
                print(
                    'Epoch: [{}/{}], Iter: [{}], Time: {:.2f}, Data: {:.2f},'
                    ' lr_unet: {:.6f}, Accuracy: {:4.2f}, Jaccard: [{:4.2f},{:4.2f},{:4.2f}], '
                    'Loss: {:.6f}'.format(epoch, args.max_iters, iter_count,
                                          batch_time.average(), data_time.average(),
                                          args.running_lr_encoder, ave_acc.average(),
                                          ave_j1.average(), ave_j2.average(),
                                          ave_j3.average(), ave_total_loss.average()))

    # Average jaccard across classes.
    j_avg = (ave_j1.average() + ave_j2.average() + ave_j3.average()) / 3

    # Update the training history with the last batch's stats for this epoch.
    history['train']['epoch'].append(epoch)
    history['train']['loss'].append(loss.data.item())
    history['train']['acc'].append(acc.data.item())
    history['train']['jaccard'].append(j_avg)

    # adjust learning rate
    adjust_learning_rate(optimizers, epoch, args)
def evaluate(nets, loader, history, epoch, args):
    """Evaluate the networks on a validation loader and plot curves.

    Accumulates loss, pixel accuracy and per-class IoU across the loader,
    appends the epoch's summary to ``history['val']`` and re-renders the
    loss/accuracy figures into ``args.ckpt``.

    Args:
        nets: iterable of networks fed to ``forward_with_loss``.
        loader: validation DataLoader.
        history: nested dict of training/validation curves, mutated in place.
        epoch: current epoch number (figures only drawn when ``epoch > 0``).
        args: experiment options (reads ``num_class`` and ``ckpt``).
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    for i, batch_data in enumerate(loader):
        # forward pass
        torch.cuda.empty_cache()
        pred, err = forward_with_loss(nets, batch_data, args, is_train=False)
        # BUG FIX: `err.data[0]` is pre-0.4 PyTorch scalar indexing and raises
        # IndexError on 0-dim tensors in modern PyTorch; `.item()` is the
        # supported extraction (and matches the `.item()` usage elsewhere).
        loss_meter.update(err.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)
        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        visualize(batch_data, pred, args)

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[Eval Summary]:')
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accurarcy: {:.2f}%'.format(
        epoch, loss_meter.average(), iou.mean(), acc_meter.average() * 100))

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())
    history['val']['mIoU'].append(iou.mean())

    # Plot figure
    if epoch > 0:
        print('Plotting loss figure...')
        fig = plt.figure()
        # log-scale the loss so early large values do not flatten the curve
        plt.plot(np.asarray(history['train']['epoch']),
                 np.log(np.asarray(history['train']['err'])),
                 color='b', label='training')
        plt.plot(np.asarray(history['val']['epoch']),
                 np.log(np.asarray(history['val']['err'])),
                 color='c', label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Log(loss)')
        fig.savefig('{}/loss.png'.format(args.ckpt), dpi=200)
        plt.close('all')

        fig = plt.figure()
        plt.plot(history['train']['epoch'], history['train']['acc'],
                 color='b', label='training')
        plt.plot(history['val']['epoch'], history['val']['acc'],
                 color='c', label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        fig.savefig('{}/accuracy.png'.format(args.ckpt), dpi=200)
        plt.close('all')
def evaluate(nets, loader, history, epoch, args):
    """Evaluate the networks (loss + pixel accuracy only) and plot curves.

    Simpler variant without IoU: accumulates loss and accuracy over the
    loader, appends the summary to ``history['val']`` and re-renders the
    loss/accuracy figures into ``args.ckpt``.

    Args:
        nets: iterable of networks fed to ``forward_with_loss``.
        loader: validation DataLoader.
        history: nested dict of training/validation curves, mutated in place.
        epoch: current epoch number (figures only drawn when ``epoch > 0``).
        args: experiment options (reads ``ckpt``).
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    for i, batch_data in enumerate(loader):
        # forward pass
        pred, err = forward_with_loss(nets, batch_data, args, is_train=False)
        # BUG FIX: `err.data[0]` is pre-0.4 PyTorch scalar indexing and raises
        # IndexError on 0-dim tensors in modern PyTorch; use `.item()`.
        loss_meter.update(err.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)

        # visualization
        visualize(batch_data, pred, args)

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())

    print('[Eval Summary] Epoch: {}, Loss: {}, Accurarcy: {:4.2f}%'.format(
        epoch, loss_meter.average(), acc_meter.average() * 100))

    # Plot figure
    if epoch > 0:
        print('Plotting loss figure...')
        fig = plt.figure()
        # log-scale the loss so early large values do not flatten the curve
        plt.plot(np.asarray(history['train']['epoch']),
                 np.log(np.asarray(history['train']['err'])),
                 color='b', label='training')
        plt.plot(np.asarray(history['val']['epoch']),
                 np.log(np.asarray(history['val']['err'])),
                 color='c', label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Log(loss)')
        fig.savefig('{}/loss.png'.format(args.ckpt), dpi=200)
        plt.close('all')

        fig = plt.figure()
        plt.plot(history['train']['epoch'], history['train']['acc'],
                 color='b', label='training')
        plt.plot(history['val']['epoch'], history['val']['acc'],
                 color='c', label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        fig.savefig('{}/accuracy.png'.format(args.ckpt), dpi=200)
        plt.close('all')
def train(segmentation_module, loader_train, optimizers, epoch, space):
    """Train for one epoch under a hyper-parameter search ``space``.

    Adjusts the learning rate from ``space['lr']`` up front, then runs the
    standard train loop, printing running stats every 40 samples.

    Args:
        segmentation_module: model returning ``(loss, (acc, jaccard))`` when
            called with ``(batch_data, epoch)``.
        loader_train: training DataLoader.
        optimizers: sequence of optimizers to step each iteration.
        epoch: 1-based current epoch number.
        space: hyper-parameter dict; only ``space['lr']`` is read here.

    Returns:
        The epoch's average total loss (float).
    """
    adjust_learning_rate(optimizers, epoch, space['lr'])
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    ave_j1 = AverageMeter()
    ave_j2 = AverageMeter()
    ave_j3 = AverageMeter()

    segmentation_module.train()

    # main loop
    tic = time.time()
    iter_count = 0
    for batch_data in loader_train:
        data_time.update(time.time() - tic)
        batch_data["image"] = batch_data["image"].cuda()
        batch_data["mask"] = batch_data["mask"].cuda()
        segmentation_module.zero_grad()

        # forward pass
        loss, acc = segmentation_module(batch_data, epoch)
        loss = loss.mean()
        # BUG FIX: the original `for j in jaccard: j = j.float().mean()` only
        # rebound the loop variable, so the per-class tensors were never
        # reduced before the `.item()` calls below.
        jaccard = [j.float().mean() for j in acc[1]]
        acc = acc[0].float().mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        # NOTE(review): 4 is presumably the (hard-coded) batch size — confirm
        # against the loader configuration.
        iter_count += 4

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)
        ave_j1.update(jaccard[0].data.item() * 100)
        ave_j2.update(jaccard[1].data.item() * 100)
        ave_j3.update(jaccard[2].data.item() * 100)

        # display every 40 samples (10 batches at batch size 4)
        if iter_count % 40 == 0:
            print('Epoch: [{}/{}], Iter: [{}], Time: {:.2f}, Data: {:.2f},'
                  'Accuracy: {:4.2f}, Jaccard: [{:4.2f},{:4.2f},{:4.2f}], '
                  'Loss: {:.6f}'.format(epoch, 30, iter_count,
                                        batch_time.average(), data_time.average(),
                                        ave_acc.average(), ave_j1.average(),
                                        ave_j2.average(), ave_j3.average(),
                                        ave_total_loss.average()))

    # Average jaccard across classes.
    j_avg = (ave_j1.average() + ave_j2.average() + ave_j3.average()) / 3

    return ave_total_loss.average()