def main():
    net = DPNet().cuda()
    # net = nn.DataParallel(net, device_ids=[0])
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
    net.eval()
    with torch.no_grad():
        results = {}
        for name, root in to_test.items():
            precision_record, recall_record = [AvgMeter() for _ in range(256)], \
                                              [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            time_record = AvgMeter()
            img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                check_mkdir(os.path.join(ckpt_path, exp_name,
                                         '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
                start = time.time()
                img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                # Variable(..., volatile=True) is deprecated; torch.no_grad() above
                # already disables autograd for inference
                img_var = img_transform(img).unsqueeze(0).cuda()
                prediction = net(img_var)
                w, h = img.size
                # F.upsample_bilinear is deprecated in favor of F.interpolate
                prediction = F.interpolate(prediction, size=(h, w), mode='bilinear', align_corners=True)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)
                end = time.time()
                time_record.update(end - start)
                # the original read the '.jpg' input image as ground truth here,
                # which would compare the prediction against the grayscale input;
                # the sibling scripts all load '<img_name>.png' masks, so this
                # looks like a typo and is corrected accordingly
                gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)
                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(ckpt_path, exp_name,
                                     '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                                     img_name + '.jpg'))
            max_fmeasure, mean_fmeasure = cal_fmeasure_both(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])
            results[name] = {'max_fmeasure': max_fmeasure, 'mae': mae_record.avg,
                             'mean_fmeasure': mean_fmeasure}
    print('test results:')
    print(results)
    print('Running time %.6f \n' % time_record.avg)
    with open('dpnet_result', 'a') as f:
        # note: exp_predict is referenced here but not defined in this listing
        f.write('\n%s \n %s: \n' % (exp_name, exp_predict))
        f.write('Running time %.6f \n' % time_record.avg)
        for name, value in results.items():
            f.write('%s: max_fmeasure: %.10f, mae: %.10f, mean_fmeasure: %.10f\n' %
                    (name, value['max_fmeasure'], value['mae'], value['mean_fmeasure']))
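# Every script in this listing leans on an AvgMeter helper for running averages,
# but its definition is not included. A minimal sketch that matches how it is
# used here (update(value) / update(value, n), .avg, .reset()); an assumption,
# not the repo's actual class:

class AvgMeter(object):
    """Tracks a running (optionally batch-weighted) average of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        # n is typically the batch size, so per-sample losses are weighted correctly
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count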
def train(net, optimizer):
    curr_iter = 1
    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_4_record, loss_3_record, loss_2_record, loss_1_record, \
        loss_c_record, loss_b_record, loss_o_record, loss_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) /
                                        (args['epoch_num'] * len(train_loader))) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            inputs, labels, edges = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            edges = Variable(edges).cuda(device_ids[0])
            optimizer.zero_grad()
            predict_4, predict_3, predict_2, predict_1, predict_c, predict_b, predict_o = net(inputs)
            loss_4 = L.lovasz_hinge(predict_4, labels)
            loss_3 = L.lovasz_hinge(predict_3, labels)
            loss_2 = L.lovasz_hinge(predict_2, labels)
            loss_1 = L.lovasz_hinge(predict_1, labels)
            loss_c = L.lovasz_hinge(predict_c, labels)
            loss_b = bce(predict_b, edges)
            loss_o = 2 * L.lovasz_hinge(predict_o, labels)
            loss = loss_4 + loss_3 + loss_2 + loss_1 + loss_c + loss_b + loss_o
            loss.backward()
            optimizer.step()
            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)
            loss_c_record.update(loss_c.data, batch_size)
            loss_b_record.update(loss_b.data, batch_size)
            loss_o_record.update(loss_o.data, batch_size)
            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)
                writer.add_scalar('loss_c', loss_c, curr_iter)
                writer.add_scalar('loss_b', loss_b, curr_iter)
                writer.add_scalar('loss_o', loss_o, curr_iter)
            log = '[%3d], [%5d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], ' \
                  '[L2: %.5f], [L1: %.5f], [Lc: %.5f], [Lb: %.5f], [Lo: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg,
                   loss_2_record.avg, loss_1_record.avg, loss_c_record.avg, loss_b_record.avg,
                   loss_o_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')
            curr_iter += 1
        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])
        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization done!")
            return
def main():
    net = R3Net(motion='', se_layer=False, attention=True, dilation=True, basic_model='resnet50')
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                   map_location='cuda:2'))
    net.eval()
    net.cuda()
    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            precision_record, recall_record = [AvgMeter() for _ in range(256)], \
                                              [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name,
                                         '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                if name == 'VOS':
                    img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                else:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                shape = img.size
                img = img.resize(args['input_size'])
                # volatile=True is deprecated; torch.no_grad() already covers inference
                img_var = img_transform(img).unsqueeze(0).cuda()
                start = time.time()
                prediction = net(img_var)
                end = time.time()
                print('running time:', (end - start))
                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision).astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)
                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)
                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name,
                                             '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))
            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])
            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}
    print('test results:')
    print(results)
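# MaxMinNormalization is used above (and again in the SNet script further down)
# to stretch a float prediction map to [0, 1] before converting to uint8. Its
# definition is not part of this listing; a plausible one-liner consistent with
# the call sites MaxMinNormalization(x, x.max(), x.min()):

def MaxMinNormalization(x, x_max, x_min):
    # the small epsilon guards against a constant map, where max == min
    # would otherwise divide by zero
    return (x - x_min) / (x_max - x_min + 1e-8)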
def train():
    g = Generator(scale_factor=train_args['scale_factor']).cuda().train()
    g = nn.DataParallel(g, device_ids=[0, 1])
    if len(train_args['g_snapshot']) > 0:
        print('load generator snapshot ' + train_args['g_snapshot'])
        g.load_state_dict(torch.load(os.path.join(train_args['ckpt_path'], train_args['g_snapshot'])))
    mse_criterion = nn.MSELoss().cuda()
    tv_criterion = TotalVariationLoss().cuda()
    g_mse_loss_record, g_tv_loss_record, g_loss_record, psnr_record = \
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    iter_nums = len(train_loader)

    if g_pretrain_args['pretrain']:
        g_optimizer = optim.Adam(g.parameters(), lr=g_pretrain_args['lr'])
        scheduler = optim.lr_scheduler.MultiStepLR(g_optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.5)
        for epoch in range(g_pretrain_args['epoch_num']):
            scheduler.step()
            start = time.time()
            for i, data in enumerate(train_loader):
                hr_imgs, _ = data
                batch_size = hr_imgs.size(0)
                lr_imgs = Variable(torch.stack([train_lr_transform(img) for img in hr_imgs], 0)).cuda()
                hr_imgs = Variable(hr_imgs).cuda()
                g.zero_grad()
                gen_hr_imgs = g(lr_imgs)
                g_mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
                # g_tv_loss = tv_criterion(gen_hr_imgs)
                g_tv_loss = 0
                g_loss = g_mse_loss + 2e-8 * g_tv_loss
                g_loss.backward()
                g_optimizer.step()
                g_mse_loss_record.update(g_mse_loss.item(), batch_size)
                # g_tv_loss_record.update(g_tv_loss.item(), batch_size)
                g_loss_record.update(g_loss.item(), batch_size)
                psnr_record.update(10 * np.log10(1 / g_mse_loss.item()), batch_size)
                print('[pretrain]: [epoch %d], [iter %d / %d], [loss %.5f], [psnr %.5f]' %
                      (epoch + 1, i + 1, iter_nums, g_loss_record.avg, psnr_record.avg))
                writer.add_scalar('pretrain_g_loss', g_loss_record.avg, epoch * iter_nums + i + 1)
                writer.add_scalar('pretrain_psnr', psnr_record.avg, epoch * iter_nums + i + 1)
            torch.save(g.state_dict(),
                       os.path.join(train_args['ckpt_path'],
                                    'pretrain_g_epoch_%d_loss_%.5f_psnr_%.5f.pth' %
                                    (epoch + 1, g_loss_record.avg, psnr_record.avg)))
            end = time.time()
            print('[time for last epoch: %.5f] [pretrain]: [epoch %d], [iter %d / %d], [loss %.5f], [psnr %.5f]' %
                  (end - start, epoch + 1, i + 1, iter_nums, g_loss_record.avg, psnr_record.avg))
            g_mse_loss_record.reset()
            psnr_record.reset()
            validate(g, epoch)

    d = Discriminator().cuda().train()
    d = nn.DataParallel(d, device_ids=[0, 1])
    if len(train_args['d_snapshot']) > 0:
        print('load discriminator snapshot ' + train_args['d_snapshot'])
        d.load_state_dict(torch.load(os.path.join(train_args['ckpt_path'], train_args['d_snapshot'])))
    g_optimizer = optim.Adam(g.parameters(), lr=train_args['g_lr'])
    d_optimizer = optim.Adam(d.parameters(), lr=train_args['d_lr'])
    g_scheduler = optim.lr_scheduler.MultiStepLR(g_optimizer, milestones=[10, 20, 30, 40], gamma=0.5)
    # bug fix: the discriminator scheduler was originally built on g_optimizer,
    # so d's learning rate never decayed
    d_scheduler = optim.lr_scheduler.MultiStepLR(d_optimizer, milestones=[10, 20, 30, 40], gamma=0.5)
    perceptual_criterion, tv_criterion = PerceptualLoss().cuda(), TotalVariationLoss().cuda()
    g_mse_loss_record, g_perceptual_loss_record, g_tv_loss_record = AvgMeter(), AvgMeter(), AvgMeter()
    psnr_record, g_ad_loss_record, g_loss_record, d_loss_record = \
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

    for epoch in range(train_args['start_epoch'] - 1, train_args['epoch_num']):
        g_scheduler.step()
        d_scheduler.step()
        start = time.time()
        for i, data in enumerate(train_loader):
            hr_imgs, _ = data
            batch_size = hr_imgs.size(0)
            lr_imgs = Variable(torch.stack([train_lr_transform(img) for img in hr_imgs], 0)).cuda()
            hr_imgs = Variable(hr_imgs).cuda()
            gen_hr_imgs = g(lr_imgs)

            # update d; gen_hr_imgs.detach() because we don't want to update
            # the gradients for g when d is being updated
            d.zero_grad()
            # d_ad_loss = - torch.log10(1 - d(gen_hr_imgs.detach())).mean() - torch.log10(d(hr_imgs)).mean()
            d_ad_loss = d(gen_hr_imgs.detach()).mean() - d(hr_imgs).mean()
            d_ad_loss.backward()
            d_optimizer.step()
            d_loss_record.update(d_ad_loss.item(), batch_size)
            # WGAN-style weight clipping on the critic
            for p in d.parameters():
                p.data.clamp_(-train_args['c'], train_args['c'])

            # update g
            g.zero_grad()
            g_mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
            g_perceptual_loss = perceptual_criterion(gen_hr_imgs, hr_imgs)
            g_tv_loss = tv_criterion(gen_hr_imgs)
            # g_ad_loss = -torch.log10(d(gen_hr_imgs)).mean()
            g_ad_loss = -d(gen_hr_imgs).mean()
            g_loss = g_mse_loss + 0.006 * g_perceptual_loss + 0.001 * g_ad_loss + 2e-8 * g_tv_loss
            g_loss.backward()
            g_optimizer.step()
            g_mse_loss_record.update(g_mse_loss.item(), batch_size)
            g_perceptual_loss_record.update(g_perceptual_loss.item(), batch_size)
            g_tv_loss_record.update(g_tv_loss.item(), batch_size)
            psnr_record.update(10 * np.log10(1 / g_mse_loss.item()), batch_size)
            g_ad_loss_record.update(g_ad_loss.item(), batch_size)
            g_loss_record.update(g_loss.item(), batch_size)
            print('[train]: [epoch %d], [iter %d / %d], [d_ad_loss %.5f], [g_ad_loss %.5f], [psnr %.5f], '
                  '[g_mse_loss %.5f], [g_perceptual_loss %.5f], [g_tv_loss %.5f] [g_loss %.5f]' %
                  (epoch + 1, i + 1, iter_nums, d_loss_record.avg, g_ad_loss_record.avg, psnr_record.avg,
                   g_mse_loss_record.avg, g_perceptual_loss_record.avg, g_tv_loss_record.avg,
                   g_loss_record.avg))
            writer.add_scalar('d_loss', d_loss_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('g_mse_loss', g_mse_loss_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('g_perceptual_loss', g_perceptual_loss_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('g_tv_loss', g_tv_loss_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('psnr', psnr_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('g_ad_loss', g_ad_loss_record.avg, epoch * iter_nums + i + 1)
            writer.add_scalar('g_loss', g_loss_record.avg, epoch * iter_nums + i + 1)
        end = time.time()
        print('[time for last epoch: %.5f][train]: [epoch %d], [iter %d / %d], [d_ad_loss %.5f], '
              '[g_ad_loss %.5f], [psnr %.5f], [g_mse_loss %.5f], [g_perceptual_loss %.5f], '
              '[g_tv_loss %.5f] [g_loss %.5f]' %
              (end - start, epoch + 1, i + 1, iter_nums, d_loss_record.avg, g_ad_loss_record.avg,
               psnr_record.avg, g_mse_loss_record.avg, g_perceptual_loss_record.avg,
               g_tv_loss_record.avg, g_loss_record.avg))
        d_loss_record.reset()
        g_mse_loss_record.reset()
        g_perceptual_loss_record.reset()
        g_tv_loss_record.reset()
        psnr_record.reset()
        g_ad_loss_record.reset()
        g_loss_record.reset()
        validate(g, epoch, d)
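# TotalVariationLoss above is weighted by 2e-8 to gently smooth the generator
# output; the class itself is defined elsewhere in the repo. A minimal sketch
# of the usual anisotropic TV penalty on an NCHW batch, under that assumption:

import torch
import torch.nn as nn

class TotalVariationLoss(nn.Module):
    """Mean absolute difference between neighboring pixels (assumed form)."""

    def forward(self, x):
        # horizontal and vertical neighbor differences
        tv_h = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
        tv_w = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
        return tv_h + tv_w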
def train_online(net, seq_name='breakdance'):
    online_args = {
        'iter_num': 100,
        'train_batch_size': 5,
        'lr': 1e-8,
        'lr_decay': 0.95,
        'weight_decay': 5e-4,
        'momentum': 0.95,
    }
    joint_transform = joint_transforms.Compose([
        joint_transforms.ImageResize(473),
        # joint_transforms.RandomCrop(473),
        # joint_transforms.RandomHorizontallyFlip(),
        # joint_transforms.RandomRotate(10)
    ])
    target_transform = transforms.ToTensor()
    train_set = VideoFirstImageFolder(to_test['davis'], gt_root, seq_name,
                                      online_args['train_batch_size'], joint_transform,
                                      img_transform, target_transform)
    online_train_loader = DataLoader(train_set, batch_size=online_args['train_batch_size'],
                                     num_workers=1, shuffle=False)
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * online_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': online_args['lr'], 'weight_decay': online_args['weight_decay']}
    ], momentum=online_args['momentum'])
    criterion = nn.BCEWithLogitsLoss().cuda()
    net.train().cuda()
    fix_parameters(net.named_parameters())
    for curr_iter in range(0, online_args['iter_num']):
        total_loss_record, loss0_record, loss1_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss2_record, loss3_record, loss4_record = AvgMeter(), AvgMeter(), AvgMeter()
        for i, data in enumerate(online_train_loader):
            optimizer.param_groups[0]['lr'] = 2 * online_args['lr'] * (
                1 - float(curr_iter) / online_args['iter_num']) ** online_args['lr_decay']
            optimizer.param_groups[1]['lr'] = online_args['lr'] * (
                1 - float(curr_iter) / online_args['iter_num']) ** online_args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4 = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels.narrow(0, 1, 4))
            loss2 = criterion(outputs2, labels.narrow(0, 2, 3))
            loss3 = criterion(outputs3, labels.narrow(0, 3, 2))
            loss4 = criterion(outputs4, labels.narrow(0, 4, 1))
            total_loss = loss0 + loss1 + loss2 + loss3 + loss4
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.data, batch_size)
            loss0_record.update(loss0.data, batch_size)
            loss1_record.update(loss1.data, batch_size)
            loss2_record.update(loss2.data, batch_size)
            loss3_record.update(loss3.data, batch_size)
            loss4_record.update(loss4.data, batch_size)
        log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
              '[loss4 %.5f], [lr %.13f]' % \
              (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
               loss3_record.avg, loss4_record.avg, optimizer.param_groups[1]['lr'])
        print(log)
    return net
def main():
    net = R3Net(motion='')
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                   map_location='cuda:0'))
    net.eval()
    net.cuda()
    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            precision_record, recall_record = [AvgMeter() for _ in range(256)], \
                                              [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name,
                                         '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_names in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                img_seq = img_names.split(',')
                img_var = []
                for img_name in img_seq:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    # volatile=True is deprecated; torch.no_grad() already covers inference
                    img_var.append(img_transform(img).unsqueeze(0).cuda())
                img_var = torch.cat(img_var, dim=0)
                prediction = net(img_var)
                # evaluate the prediction for the last frame of the sequence
                precision = to_pil(prediction.data[-1].cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)
                gt = np.array(Image.open(os.path.join(gt_root, img_seq[-1] + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)
                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name,
                                             '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))
            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])
            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}
    print('test results:')
    print(results)
def train(net, optimizer):
    global total_epoch
    curr_iter = 1
    start_time = time.time()
    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_4_record, loss_3_record, loss_2_record, loss_1_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) / float(total_epoch)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            if args['poly_warmup']:
                if curr_iter < args['warmup_epoch']:
                    # linear warmup ramp (note: it ramps toward 1.0, not toward args['lr'])
                    base_lr = 1 / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    # shift into local variables; the original mutated curr_iter and
                    # total_epoch in place on every iteration, corrupting the schedule
                    t = curr_iter - args['warmup_epoch'] + 1
                    total = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 - float(t) / float(total)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            if args['cosine_warmup']:
                if curr_iter < args['warmup_epoch']:
                    base_lr = 1 / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    t = curr_iter - args['warmup_epoch'] + 1
                    total = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 + np.cos(np.pi * float(t) / float(total))) / 2
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            if args['f3_sche']:
                # triangular (up-then-down) schedule
                base_lr = args['lr'] * (1 - abs((curr_iter + 1) / (total_epoch + 1) * 2 - 1))
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            optimizer.zero_grad()
            predict_4, predict_3, predict_2, predict_1 = net(inputs)
            loss_4 = bce_iou_edge_loss(predict_4, labels)
            loss_3 = bce_iou_edge_loss(predict_3, labels)
            loss_2 = bce_iou_edge_loss(predict_2, labels)
            loss_1 = bce_iou_edge_loss(predict_1, labels)
            loss = args['w2'][0] * loss_4 + args['w2'][1] * loss_3 + \
                   args['w2'][2] * loss_2 + args['w2'][3] * loss_1
            loss.backward()
            optimizer.step()
            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)
            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)
            log = '[%3d], [%6d], [%.6f], [%.5f], [%.5f], [%.5f], [%.5f], [%.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg,
                   loss_3_record.avg, loss_2_record.avg, loss_1_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')
            curr_iter += 1
        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])
        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Total Training Time: {}".format(
                str(datetime.timedelta(seconds=int(time.time() - start_time)))))
            print("Optimization done!")
            return
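# The schedule logic above repeats the same poly/warmup arithmetic in several
# branches. For reference, the intended curve can be factored into one helper;
# this is a sketch of the math as written above, not a drop-in from the repo:

def poly_lr_with_warmup(base_lr, curr_iter, total_iter, warmup_iter, decay):
    if curr_iter < warmup_iter:
        # linear warmup ramp, matching the scripts' ramp toward 1.0
        return 1.0 / warmup_iter * (1 + curr_iter)
    # poly decay over the post-warmup portion of training
    t = curr_iter - warmup_iter + 1
    total = total_iter - warmup_iter + 1
    return base_lr * (1 - float(t) / float(total)) ** decay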
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        # loss3_record = AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, labels = data
            if args['train_loader'] == 'video_sequence':
                inputs = inputs.squeeze(0)
                labels = labels.squeeze(0)
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, _, _ = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            # loss3 = criterion(outputs3, labels)
            # loss4 = criterion(outputs4, labels)
            if args['distillation']:
                # F.sigmoid is deprecated; torch.sigmoid is the modern equivalent
                loss02 = criterion(outputs0, torch.sigmoid(outputs2))
                loss12 = criterion(outputs1, torch.sigmoid(outputs2))
                total_loss = loss0 + loss1 + loss2 + 0.5 * loss02 + 0.5 * loss12
            else:
                total_loss = loss0 + loss1 + loss2
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.data, batch_size)
            loss0_record.update(loss0.data, batch_size)
            loss1_record.update(loss1.data, batch_size)
            loss2_record.update(loss2.data, batch_size)
            # loss3_record.update(loss3.data, batch_size)
            # loss4_record.update(loss4.data, batch_size)
            curr_iter += 1
            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f] ' \
                  '[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg,
                   loss2_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            if curr_iter % args['iter_save'] == 0:
                print('taking snapshot ...')
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
                return
def train(net, optimizer):
    curr_iter = 1
    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_f4_record, loss_f3_record, loss_f2_record, loss_f1_record, \
        loss_b4_record, loss_b3_record, loss_b2_record, loss_b1_record, \
        loss_e_record, loss_fb_record, loss_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), \
            AvgMeter(), AvgMeter(), AvgMeter()
        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) /
                                        (args['epoch_num'] * len(train_loader))) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            inputs, labels, edges = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            edges = Variable(edges).cuda(device_ids[0])
            optimizer.zero_grad()
            predict_f4, predict_f3, predict_f2, predict_f1, \
            predict_b4, predict_b3, predict_b2, predict_b1, predict_e, predict_fb = net(inputs)
            loss_f4 = wl(predict_f4, labels)
            loss_f3 = wl(predict_f3, labels)
            loss_f2 = wl(predict_f2, labels)
            loss_f1 = wl(predict_f1, labels)
            # loss_b4 = wl(1 - torch.sigmoid(predict_b4), labels)
            # loss_b3 = wl(1 - torch.sigmoid(predict_b3), labels)
            # loss_b2 = wl(1 - torch.sigmoid(predict_b2), labels)
            # loss_b1 = wl(1 - torch.sigmoid(predict_b1), labels)
            loss_b4 = wl(1 - predict_b4, labels)
            loss_b3 = wl(1 - predict_b3, labels)
            loss_b2 = wl(1 - predict_b2, labels)
            loss_b1 = wl(1 - predict_b1, labels)
            loss_e = el(predict_e, edges)
            loss_fb = wl(predict_fb, labels)
            loss = loss_f4 + loss_f3 + loss_f2 + loss_f1 + \
                   loss_b4 + loss_b3 + loss_b2 + loss_b1 + loss_e + 8 * loss_fb
            loss.backward()
            optimizer.step()
            loss_record.update(loss.data, batch_size)
            loss_f4_record.update(loss_f4.data, batch_size)
            loss_f3_record.update(loss_f3.data, batch_size)
            loss_f2_record.update(loss_f2.data, batch_size)
            loss_f1_record.update(loss_f1.data, batch_size)
            loss_b4_record.update(loss_b4.data, batch_size)
            loss_b3_record.update(loss_b3.data, batch_size)
            loss_b2_record.update(loss_b2.data, batch_size)
            loss_b1_record.update(loss_b1.data, batch_size)
            loss_e_record.update(loss_e.data, batch_size)
            loss_fb_record.update(loss_fb.data, batch_size)
            if curr_iter % 50 == 0:
                writer.add_scalar('Total loss', loss, curr_iter)
                writer.add_scalar('f4 loss', loss_f4, curr_iter)
                writer.add_scalar('f3 loss', loss_f3, curr_iter)
                writer.add_scalar('f2 loss', loss_f2, curr_iter)
                writer.add_scalar('f1 loss', loss_f1, curr_iter)
                writer.add_scalar('b4 loss', loss_b4, curr_iter)
                writer.add_scalar('b3 loss', loss_b3, curr_iter)
                writer.add_scalar('b2 loss', loss_b2, curr_iter)
                writer.add_scalar('b1 loss', loss_b1, curr_iter)
                writer.add_scalar('e loss', loss_e, curr_iter)
                writer.add_scalar('fb loss', loss_fb, curr_iter)
            log = '[%3d], [f4 %.5f], [f3 %.5f], [f2 %.5f], [f1 %.5f] ' \
                  '[b4 %.5f], [b3 %.5f], [b2 %.5f], [b1 %.5f], [e %.5f], [fb %.5f], [lr %.6f]' % \
                  (epoch, loss_f4_record.avg, loss_f3_record.avg, loss_f2_record.avg, loss_f1_record.avg,
                   loss_b4_record.avg, loss_b3_record.avg, loss_b2_record.avg, loss_b1_record.avg,
                   loss_e_record.avg, loss_fb_record.avg, base_lr)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')
            curr_iter += 1
        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])
        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization done!")
            return
def main():
    net = R3Net().cuda()
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
    net.eval()
    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            precision_record, recall_record = [AvgMeter() for _ in range(256)], \
                                              [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name,
                                         '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                # volatile=True is deprecated; torch.no_grad() already covers inference
                img_var = img_transform(img).unsqueeze(0).cuda()
                prediction = net(img_var)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)
                gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)
                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(ckpt_path, exp_name,
                                     '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                                     img_name + '.png'))
            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])
            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}
    print('test results:')
    print(results)
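# Every test script sweeps 256 thresholds through cal_precision_recall_mae,
# which is imported from elsewhere. A sketch of the standard computation on
# uint8 prediction/gt maps, matching the (precision, recall, mae) return
# signature used above; an assumption, not the repo's exact code:

import numpy as np

def cal_precision_recall_mae(prediction, gt):
    assert prediction.dtype == np.uint8 and gt.dtype == np.uint8
    eps = 1e-4
    prediction = prediction / 255.0
    gt = gt / 255.0
    mae = np.mean(np.abs(prediction - gt))
    hard_gt = gt > 0.5
    t = hard_gt.sum()
    precision, recall = [], []
    # binarize at every threshold in [0, 255] and collect the P/R curve
    for threshold in range(256):
        p = prediction >= threshold / 255.0
        tp = (p & hard_gt).sum()
        precision.append((tp + eps) / (p.sum() + eps))
        recall.append((tp + eps) / (t + eps))
    return precision, recall, mae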
def main():
    net = SDCNet(num_classes=5).cuda()
    print('load snapshot \'%s\' for testing, mode:\'%s\'' % (args['snapshot'], args['test_mode']))
    print(exp_name)
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
    net.eval()
    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            print('load snapshot \'%s\' for testing %s' % (args['snapshot'], name))
            test_data = pd.read_csv(root)
            test_set = TestFolder_joint(test_data, joint_transform, img_transform, target_transform)
            test_loader = DataLoader(test_set, batch_size=1, num_workers=0, shuffle=False)
            # per-class P/R curves (records 1-5) plus an overall record (6);
            # record 0 is allocated but never updated
            precision0_record, recall0_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision1_record, recall1_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision2_record, recall2_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision3_record, recall3_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision4_record, recall4_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision5_record, recall5_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            precision6_record, recall6_record = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            mae0_record = AvgMeter()
            mae1_record = AvgMeter()
            mae2_record = AvgMeter()
            mae3_record = AvgMeter()
            mae4_record = AvgMeter()
            mae5_record = AvgMeter()
            mae6_record = AvgMeter()
            n0, n1, n2, n3, n4, n5 = 0, 0, 0, 0, 0, 0
            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '%s_%s' % (name, args['snapshot'])))
            for i, (inputs, gt, labels, img_path) in enumerate(tqdm(test_loader)):
                shape = gt.size()[2:]
                img_var = Variable(inputs).cuda()
                img = np.array(to_pil(img_var.data.squeeze(0).cpu()))
                gt = np.array(to_pil(gt.data.squeeze(0).cpu()))
                sizec = labels.numpy()
                pred2021 = net(img_var, sizec)
                pred2021 = F.interpolate(pred2021, size=shape, mode='bilinear', align_corners=True)
                pred2021 = np.array(to_pil(pred2021.data.squeeze(0).cpu()))
                if labels == 0:
                    precision1, recall1, mae1 = cal_precision_recall_mae(pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision1, recall1)):
                        p, r = pdata
                        precision1_record[pidx].update(p)
                        recall1_record[pidx].update(r)
                    mae1_record.update(mae1)
                    n1 += 1
                elif labels == 1:
                    precision2, recall2, mae2 = cal_precision_recall_mae(pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision2, recall2)):
                        p, r = pdata
                        precision2_record[pidx].update(p)
                        recall2_record[pidx].update(r)
                    mae2_record.update(mae2)
                    n2 += 1
                elif labels == 2:
                    precision3, recall3, mae3 = cal_precision_recall_mae(pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision3, recall3)):
                        p, r = pdata
                        precision3_record[pidx].update(p)
                        recall3_record[pidx].update(r)
                    mae3_record.update(mae3)
                    n3 += 1
                elif labels == 3:
                    precision4, recall4, mae4 = cal_precision_recall_mae(pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision4, recall4)):
                        p, r = pdata
                        precision4_record[pidx].update(p)
                        recall4_record[pidx].update(r)
                    mae4_record.update(mae4)
                    n4 += 1
                elif labels == 4:
                    precision5, recall5, mae5 = cal_precision_recall_mae(pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision5, recall5)):
                        p, r = pdata
                        precision5_record[pidx].update(p)
                        recall5_record[pidx].update(r)
                    mae5_record.update(mae5)
                    n5 += 1
                # overall record, updated for every sample regardless of class
                precision6, recall6, mae6 = cal_precision_recall_mae(pred2021, gt)
                for pidx, pdata in enumerate(zip(precision6, recall6)):
                    p, r = pdata
                    precision6_record[pidx].update(p)
                    recall6_record[pidx].update(r)
                mae6_record.update(mae6)
                img_name = os.path.split(str(img_path))[1]
                img_name = os.path.splitext(img_name)[0]
                n0 += 1
                if args['save_results']:
                    Image.fromarray(pred2021).save(
                        os.path.join(ckpt_path, exp_name, '%s_%s' % (name, args['snapshot']),
                                     img_name + '_2021.png'))
            fmeasure1 = cal_fmeasure([precord.avg for precord in precision1_record],
                                     [rrecord.avg for rrecord in recall1_record])
            fmeasure2 = cal_fmeasure([precord.avg for precord in precision2_record],
                                     [rrecord.avg for rrecord in recall2_record])
            fmeasure3 = cal_fmeasure([precord.avg for precord in precision3_record],
                                     [rrecord.avg for rrecord in recall3_record])
            fmeasure4 = cal_fmeasure([precord.avg for precord in precision4_record],
                                     [rrecord.avg for rrecord in recall4_record])
            fmeasure5 = cal_fmeasure([precord.avg for precord in precision5_record],
                                     [rrecord.avg for rrecord in recall5_record])
            fmeasure6 = cal_fmeasure([precord.avg for precord in precision6_record],
                                     [rrecord.avg for rrecord in recall6_record])
            results[name] = {'fmeasure1': fmeasure1, 'mae1': mae1_record.avg,
                             'fmeasure2': fmeasure2, 'mae2': mae2_record.avg,
                             'fmeasure3': fmeasure3, 'mae3': mae3_record.avg,
                             'fmeasure4': fmeasure4, 'mae4': mae4_record.avg,
                             'fmeasure5': fmeasure5, 'mae5': mae5_record.avg,
                             'fmeasure6': fmeasure6, 'mae6': mae6_record.avg}
            print('test results:')
            print('[fmeasure1 %.3f], [mae1 %.4f], [class1 %.0f]\n'
                  '[fmeasure2 %.3f], [mae2 %.4f], [class2 %.0f]\n'
                  '[fmeasure3 %.3f], [mae3 %.4f], [class3 %.0f]\n'
                  '[fmeasure4 %.3f], [mae4 %.4f], [class4 %.0f]\n'
                  '[fmeasure5 %.3f], [mae5 %.4f], [class5 %.0f]\n'
                  '[fmeasure6 %.3f], [mae6 %.4f], [all %.0f]\n' %
                  (fmeasure1, mae1_record.avg, n1, fmeasure2, mae2_record.avg, n2,
                   fmeasure3, mae3_record.avg, n3, fmeasure4, mae4_record.avg, n4,
                   fmeasure5, mae5_record.avg, n5, fmeasure6, mae6_record.avg, n0))
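# cal_fmeasure (and cal_fmeasure_both in the DPNet script) reduce the averaged
# P/R curves to scalars. A sketch using the weighted F-beta (beta^2 = 0.3)
# convention common in saliency evaluation; an assumption, not the repo's code:

def cal_fmeasure(precision, recall):
    # maximum F-beta over the 256 thresholds of the averaged P/R curves
    beta_square = 0.3
    return max((1 + beta_square) * p * r / (beta_square * p + r + 1e-8)
               for p, r in zip(precision, recall))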
def train(exp_name):
    net = AADFNet().cuda().train()
    net = nn.DataParallel(net, device_ids=[0, 1])
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])
    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    open(log_path, 'w').write(str(args) + '\n\n')
    print('start to train')
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record = AvgMeter(), AvgMeter()
        loss2_2_record, loss3_2_record, loss4_2_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss44_record, loss43_record, loss42_record, loss41_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss34_record, loss33_record, loss32_record, loss31_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss24_record, loss23_record, loss22_record, loss21_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss14_record, loss13_record, loss12_record, loss11_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs4_2, outputs3_2, outputs2_2, outputs1, outputs2, outputs3, outputs4, \
            predict41, predict42, predict43, predict44, \
            predict31, predict32, predict33, predict34, \
            predict21, predict22, predict23, predict24, \
            predict11, predict12, predict13, predict14 = net(inputs)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss2_2 = criterion(outputs2_2, labels)
            loss3_2 = criterion(outputs3_2, labels)
            loss4_2 = criterion(outputs4_2, labels)
            loss44 = criterion(predict44, labels)
            loss43 = criterion(predict43, labels)
            loss42 = criterion(predict42, labels)
            loss41 = criterion(predict41, labels)
            loss34 = criterion(predict34, labels)
            loss33 = criterion(predict33, labels)
            loss32 = criterion(predict32, labels)
            loss31 = criterion(predict31, labels)
            loss24 = criterion(predict24, labels)
            loss23 = criterion(predict23, labels)
            loss22 = criterion(predict22, labels)
            loss21 = criterion(predict21, labels)
            loss14 = criterion(predict14, labels)
            loss13 = criterion(predict13, labels)
            loss12 = criterion(predict12, labels)
            loss11 = criterion(predict11, labels)
            # the full deep-supervision sum was immediately overridden in the
            # original; it is kept here (commented out) for reference
            # total_loss = loss1 + loss2 + loss3 + loss4 + loss2_2 + loss3_2 + loss4_2 \
            #              + (loss44 + loss43 + loss42 + loss41) / 10 \
            #              + (loss34 + loss33 + loss32 + loss31) / 10 \
            #              + (loss24 + loss23 + loss22 + loss21) / 10 \
            #              + (loss14 + loss13 + loss12 + loss11) / 10
            total_loss = loss1 + loss2 + loss3 + loss4
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss2_2_record.update(loss2_2.item(), batch_size)
            loss3_2_record.update(loss3_2.item(), batch_size)
            loss4_2_record.update(loss4_2.item(), batch_size)
            loss44_record.update(loss44.item(), batch_size)
            loss43_record.update(loss43.item(), batch_size)
            loss42_record.update(loss42.item(), batch_size)
            loss41_record.update(loss41.item(), batch_size)
            loss34_record.update(loss34.item(), batch_size)
            loss33_record.update(loss33.item(), batch_size)
            loss32_record.update(loss32.item(), batch_size)
            loss31_record.update(loss31.item(), batch_size)
            loss24_record.update(loss24.item(), batch_size)
            loss23_record.update(loss23.item(), batch_size)
            loss22_record.update(loss22.item(), batch_size)
            loss21_record.update(loss21.item(), batch_size)
            loss14_record.update(loss14.item(), batch_size)
            loss13_record.update(loss13.item(), batch_size)
            loss12_record.update(loss12.item(), batch_size)
            loss11_record.update(loss11.item(), batch_size)
            curr_iter += 1
            log = '[iter %d], [total loss %.5f], ' \
                  '[loss4_2 %.5f], [loss3_2 %.5f], [loss2_2 %.5f], [loss1 %.5f], ' \
                  '[loss2 %.5f], [loss3 %.5f], [loss4 %.5f], ' \
                  '[loss44 %.5f], [loss43 %.5f], [loss42 %.5f], [loss41 %.5f], ' \
                  '[loss34 %.5f], [loss33 %.5f], [loss32 %.5f], [loss31 %.5f], ' \
                  '[loss24 %.5f], [loss23 %.5f], [loss22 %.5f], [loss21 %.5f], ' \
                  '[loss14 %.5f], [loss13 %.5f], [loss12 %.5f], [loss11 %.5f], ' \
                  '[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg,
                   loss4_2_record.avg, loss3_2_record.avg, loss2_2_record.avg, loss1_record.avg,
                   loss2_record.avg, loss3_record.avg, loss4_record.avg,
                   loss44_record.avg, loss43_record.avg, loss42_record.avg, loss41_record.avg,
                   loss34_record.avg, loss33_record.avg, loss32_record.avg, loss31_record.avg,
                   loss24_record.avg, loss23_record.avg, loss22_record.avg, loss21_record.avg,
                   loss14_record.avg, loss13_record.avg, loss12_record.avg, loss11_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
                return
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        train_loss_record = AvgMeter()
        train_net_loss_record = AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, gts, dps = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            gts = Variable(gts).cuda()
            dps = Variable(dps).cuda()
            optimizer.zero_grad()
            result = net(inputs)
            loss_net = criterion(result, gts)
            loss = loss_net
            loss.backward()
            optimizer.step()
            # for n, p in net.named_parameters():
            #     if n[-5:] == 'alpha':
            #         print(p.grad.data)
            #         print(p.data)
            train_loss_record.update(loss.data, batch_size)
            train_net_loss_record.update(loss_net.data, batch_size)
            curr_iter += 1
            log = '[iter %d], [train loss %.5f], [lr %.13f], [loss_net %.5f]' % \
                  (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
                   train_net_loss_record.avg)
            print(log)
            open(log_path, 'a').write(log + '\n')
            if (curr_iter + 1) % args['val_freq'] == 0:
                validate(net, curr_iter, optimizer)
            if (curr_iter + 1) % args['snapshot_epochs'] == 0:
                torch.save(net.state_dict(),
                           os.path.join(ckpt_path, exp_name, ('%d.pth' % (curr_iter + 1))))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, ('%d_optim.pth' % (curr_iter + 1))))
            if curr_iter > args['iter_num']:
                return
def train(net, optimizer):
    curr_iter = args['last_iter']
    for e in range(args['epoch']):
        total_loss_record, loss0_record, loss1_record, loss2_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        print('epoch', e)
        for i, data in enumerate(train_loader):
            # optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            #     1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            # optimizer.param_groups[1]['lr'] = args['lr'] * (
            #     1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            print(inputs.size())
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)
            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()
            # .data[0] is pre-0.4 PyTorch and errors on 0-dim tensors;
            # .item() is the modern equivalent
            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)
            curr_iter += 1
            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg,
                   loss2_record.avg, loss3_record.avg, loss4_record.avg, loss5_record.avg,
                   loss6_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            if curr_iter % args['iter_num'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss2_record, loss3_record, loss4_record, loss5_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        if args['isTriplet']:
            loss_triplet_record = AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            inputs, labels = data
            if args['train_loader'] == 'video_sequence':
                inputs = inputs.squeeze(0)
                labels = labels.squeeze(0)
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            if args['isTriplet']:
                outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs_triplet = net(inputs)
            else:
                outputs0, outputs1, outputs2, outputs3, outputs4, outputs5 = net(inputs)
            # deeper outputs predict progressively fewer frames of the 6-frame clip
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels.narrow(0, 1, 5))
            loss2 = criterion(outputs2, labels.narrow(0, 2, 4))
            loss3 = criterion(outputs3, labels.narrow(0, 3, 3))
            loss4 = criterion(outputs4, labels.narrow(0, 4, 2))
            loss5 = criterion(outputs5, labels.narrow(0, 5, 1))
            if args['L2']:
                # note: these narrow sizes match a 5-frame batch and disagree with
                # the 6-frame narrows above; kept as found in the original
                loss0 = loss0 + 0.1 * criterion_l2(
                    torch.relu(outputs0) / torch.max(outputs0), labels)
                loss1 = loss1 + 0.1 * criterion_l2(
                    torch.relu(outputs1) / torch.max(outputs1), labels.narrow(0, 1, 4))
                loss2 = loss2 + 0.1 * criterion_l2(
                    torch.relu(outputs2) / torch.max(outputs2), labels.narrow(0, 2, 3))
                loss3 = loss3 + 0.1 * criterion_l2(
                    torch.relu(outputs3) / torch.max(outputs3), labels.narrow(0, 3, 2))
                loss4 = loss4 + 0.1 * criterion_l2(
                    torch.relu(outputs4) / torch.max(outputs4), labels.narrow(0, 4, 1))
            if args['dice']:
                loss0 = loss0 + 0.5 * criterion_dice(outputs0, labels)
                loss1 = loss1 + 0.5 * criterion_dice(outputs1, labels.narrow(0, 1, 5))
                loss2 = loss2 + 0.5 * criterion_dice(outputs2, labels.narrow(0, 2, 4))
                loss3 = loss3 + 0.5 * criterion_dice(outputs3, labels.narrow(0, 3, 3))
                loss4 = loss4 + 0.5 * criterion_dice(outputs4, labels.narrow(0, 4, 2))
                # bug fix: the original accumulated this term onto loss4
                # ('loss5 = loss4 + ...'), dropping loss5's BCE term
                loss5 = loss5 + 0.5 * criterion_dice(outputs5, labels.narrow(0, 5, 1))
            if args['isTriplet']:
                loss_triplet = criterion_triplet(outputs_triplet[0], outputs_triplet[1],
                                                 outputs_triplet[2])
                total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + 0.2 * loss_triplet
                total_loss.backward()
                optimizer.step()
                total_loss_record.update(total_loss.data, batch_size)
                loss0_record.update(loss0.data, batch_size)
                loss1_record.update(loss1.data, batch_size)
                loss2_record.update(loss2.data, batch_size)
                loss3_record.update(loss3.data, batch_size)
                loss4_record.update(loss4.data, batch_size)
                loss_triplet_record.update(loss_triplet.data, batch_size)
            else:
                total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5
                total_loss.backward()
                optimizer.step()
                total_loss_record.update(total_loss.data, batch_size)
                loss0_record.update(loss0.data, batch_size)
                loss1_record.update(loss1.data, batch_size)
                loss2_record.update(loss2.data, batch_size)
                loss3_record.update(loss3.data, batch_size)
                loss4_record.update(loss4.data, batch_size)
                loss5_record.update(loss5.data, batch_size)
            curr_iter += 1
            if args['isTriplet']:
                log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                      '[loss4 %.5f], [loss_triplet %.5f], [lr %.13f] ' % \
                      (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg,
                       loss2_record.avg, loss3_record.avg, loss4_record.avg,
                       loss_triplet_record.avg, optimizer.param_groups[1]['lr'])
            else:
                log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                      '[loss4 %.5f], [loss5 %.5f], [lr %.13f]' % \
                      (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg,
                       loss2_record.avg, loss3_record.avg, loss4_record.avg, loss5_record.avg,
                       optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            if curr_iter % args['iter_save'] == 0:
                print('taking snapshot ...')
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
                return
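# criterion_dice above is only referenced, never defined in this listing. A
# common soft Dice loss over sigmoid logits would look like the sketch below;
# this is an assumption about what the repo's criterion_dice computes:

import torch

def dice_loss(logits, targets, smooth=1.0):
    # flatten each sample, compute per-sample Dice, then average the batch
    probs = torch.sigmoid(logits).reshape(logits.size(0), -1)
    targets = targets.reshape(targets.size(0), -1)
    inter = (probs * targets).sum(dim=1)
    dice = (2 * inter + smooth) / (probs.sum(dim=1) + targets.sum(dim=1) + smooth)
    return 1 - dice.mean()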
def train(net, optimizer):
    global best_ber
    curr_iter = 1
    start_time = time.time()
    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_4_record, loss_3_record, loss_2_record, loss_1_record, loss_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) / float(total_epoch)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            optimizer.zero_grad()
            predict_4, predict_3, predict_2, predict_1 = net(inputs)
            loss_4 = L.lovasz_hinge(predict_4, labels)
            loss_3 = L.lovasz_hinge(predict_3, labels)
            loss_2 = L.lovasz_hinge(predict_2, labels)
            loss_1 = L.lovasz_hinge(predict_1, labels)
            loss = loss_4 + loss_3 + loss_2 + loss_1
            loss.backward()
            optimizer.step()
            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)
            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)
            log = '[%3d], [%6d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], [L2: %.5f], [L1: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg,
                   loss_3_record.avg, loss_2_record.avg, loss_1_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')
            curr_iter += 1
        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])
        if epoch >= args['epoch_thres'] and epoch % 5 == 0:
            ber = test(net)
            print("mean ber of %d epoch is %.5f" % (epoch, ber))
            if ber < best_ber:
                # the original never updated best_ber, so every epoch below the
                # initial threshold was saved; updating it here is presumably
                # the intended behavior of the global
                best_ber = ber
                net.cpu()
                torch.save(net.state_dict(),
                           os.path.join(ckpt_path, exp_name, 'epoch_%d_ber_%.2f.pth' % (epoch, ber)))
                print("The optimized epoch is %04d" % epoch)
            net = net.cuda(device_ids[0]).train()
        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Total Training Time: {}".format(
                str(datetime.timedelta(seconds=int(time.time() - start_time)))))
            print(exp_name)
            print("Optimization done!")
            return
def main():
    # net = R3Net(motion='', se_layer=False, dilation=False, basic_model='resnet50')
    net = SNet(cfg=None)
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    # net.load_state_dict(torch.load('pretrained/R2Net.pth', map_location='cuda:2'))
    # net = load_part_of_model2(net, 'pretrained/R2Net.pth', device_id=2)
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                   map_location='cuda:2'))
    net.eval()
    net.cuda()
    results = {}
    with torch.no_grad():
        for name, root in to_test.items():
            precision_record, recall_record = [AvgMeter() for _ in range(256)], \
                                              [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name,
                                         '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            video = ''
            pre_predict = None
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)
                if video != img_name.split('/')[0]:
                    # first frame of a new video clip: run the full model head
                    video = img_name.split('/')[0]
                    if name == 'VOS' or name == 'DAVSOD':
                        img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                    else:
                        img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    # volatile=True is deprecated; torch.no_grad() already covers inference
                    img_var = img_transform(img).unsqueeze(0).cuda()
                    start = time.time()
                    if args['model'] == 'BASNet':
                        prediction, _, prediction2, _, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R3Net':
                        prediction = net(img_var)
                    elif args['model'] == 'DSSNet':
                        select = [1, 2, 3, 6]
                        prediction = net(img_var)
                        prediction = torch.mean(
                            torch.cat([torch.sigmoid(prediction[i]) for i in select], dim=1),
                            dim=1, keepdim=True)
                    elif args['model'] == 'CPD':
                        prediction2, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'RAS':
                        prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'PoolNet':
                        prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'F3Net':
                        prediction2, prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R2Net':
                        _, _, _, _, _, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    end = time.time()
                    pre_predict = prediction
                    print('running time:', (end - start))
                else:
                    if name == 'VOS' or name == 'DAVSOD':
                        img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                    else:
                        img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    img_var = img_transform(img).unsqueeze(0).cuda()
                    start = time.time()
                    _, prediction, _, _, _, _ = net(img_var)
                    end = time.time()
                    print('running time:', (end - start))
                    pre_predict = prediction
                # e = Erosion2d(1, 1, 5, soft_max=False).cuda()
                # prediction2 = e(prediction)
                # precision2 = to_pil(prediction2.data.squeeze(0).cpu())
                # precision2 = prediction2.data.squeeze(0).cpu().numpy()
                # precision2 = precision2.resize(shape)
                # prediction2 = np.array(precision2)
                # prediction2 = prediction2.astype('float')
                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                prediction = prediction.astype('float')
                # plt.style.use('classic')
                # plt.subplot(1, 2, 1)
                # plt.imshow(prediction)
                # plt.subplot(1, 2, 2)
                # plt.imshow(precision2[0])
                # plt.show()
                prediction = MaxMinNormalization(prediction, prediction.max(),
                                                 prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                # if args['crf_refine']:
                #     prediction = crf_refine(np.array(img), prediction)
                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)
                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name,
                                             '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))
            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])
            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}
    print('test results:')
    print(results)
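# crf_refine is imported by all of these test scripts but not defined here.
# A minimal sketch using pydensecrf's DenseCRF2D, assuming an RGB uint8 image
# and a uint8 saliency map; the repo's actual unary construction and pairwise
# parameters may well differ:

import numpy as np
import pydensecrf.densecrf as dcrf

def crf_refine(img, annos):
    # two labels: background / foreground
    h, w = annos.shape
    prob = annos.astype(np.float32) / 255.0
    unary = np.stack([1.0 - prob, prob], axis=0).reshape(2, -1)
    unary = -np.log(np.clip(unary, 1e-8, 1.0))
    d = dcrf.DenseCRF2D(w, h, 2)
    d.setUnaryEnergy(np.ascontiguousarray(unary))
    # smoothness kernel plus an appearance kernel conditioned on the RGB image
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=np.ascontiguousarray(img), compat=5)
    q = np.array(d.inference(5)).reshape(2, h, w)
    return (q[1] * 255).astype(np.uint8)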
def train(net, optimizer):
    curr_iter = 1
    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_f_record, loss_b_record, loss_o_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) /
                                        (args['epoch_num'] * len(train_loader))) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            optimizer.zero_grad()
            predict_f, predict_b, predict_o = net(inputs)
            # foreground branch, background branch on inverted labels, and a
            # double-weighted fused output
            loss_f = L.lovasz_hinge(predict_f, labels)
            loss_b = L.lovasz_hinge(predict_b, 1 - labels)
            loss_o = 2 * L.lovasz_hinge(predict_o, labels)
            loss = loss_f + loss_b + loss_o
            loss.backward()
            optimizer.step()
            loss_record.update(loss.data, batch_size)
            loss_f_record.update(loss_f.data, batch_size)
            loss_b_record.update(loss_b.data, batch_size)
            loss_o_record.update(loss_o.data, batch_size)
            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_f', loss_f, curr_iter)
                writer.add_scalar('loss_b', loss_b, curr_iter)
                writer.add_scalar('loss_o', loss_o, curr_iter)
            log = '[Epoch: %2d], [Iter: %5d], [%.7f], [Sum: %.5f], [Lf: %.5f], [Lb: %.5f], [Lo: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_f_record.avg,
                   loss_b_record.avg, loss_o_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')
            curr_iter += 1
        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])
        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization done!")
            return
def train(net, optimizer):
    net.print_network()
    curr_iter = args['last_iter']
    while True:
        train_loss_record, loss_fuse_record, loss1_h2l_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss2_h2l_record, loss3_h2l_record, loss4_h2l_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss1_l2h_record, loss2_l2h_record, loss3_l2h_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss4_l2h_record = AvgMeter()

        for i, data in enumerate(train_loader):
            # Poly decay applied per parameter group; the first group (typically
            # biases) runs at twice the base learning rate.
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            fuse_predict, predict1_h2l, predict2_h2l, predict3_h2l, predict4_h2l, \
                predict1_l2h, predict2_l2h, predict3_l2h, predict4_l2h = net(inputs)

            # Deep supervision: every high-to-low and low-to-high side output is
            # trained against the same ground-truth mask as the fused prediction.
            loss_fuse = bce_logit(fuse_predict, labels)
            loss1_h2l = bce_logit(predict1_h2l, labels)
            loss2_h2l = bce_logit(predict2_h2l, labels)
            loss3_h2l = bce_logit(predict3_h2l, labels)
            loss4_h2l = bce_logit(predict4_h2l, labels)
            loss1_l2h = bce_logit(predict1_l2h, labels)
            loss2_l2h = bce_logit(predict2_l2h, labels)
            loss3_l2h = bce_logit(predict3_l2h, labels)
            loss4_l2h = bce_logit(predict4_l2h, labels)

            loss = loss_fuse + loss1_h2l + loss2_h2l + loss3_h2l + loss4_h2l + \
                loss1_l2h + loss2_l2h + loss3_l2h + loss4_l2h

            loss.backward()
            optimizer.step()

            train_loss_record.update(loss.item(), batch_size)
            loss_fuse_record.update(loss_fuse.item(), batch_size)
            loss1_h2l_record.update(loss1_h2l.item(), batch_size)
            loss2_h2l_record.update(loss2_h2l.item(), batch_size)
            loss3_h2l_record.update(loss3_h2l.item(), batch_size)
            loss4_h2l_record.update(loss4_h2l.item(), batch_size)
            loss1_l2h_record.update(loss1_l2h.item(), batch_size)
            loss2_l2h_record.update(loss2_l2h.item(), batch_size)
            loss3_l2h_record.update(loss3_l2h.item(), batch_size)
            loss4_l2h_record.update(loss4_l2h.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [loss_fuse %.5f], [loss1_h2l %.5f], [loss2_h2l %.5f], ' \
                  '[loss3_h2l %.5f], [loss4_h2l %.5f], [loss1_l2h %.5f], [loss2_l2h %.5f], [loss3_l2h %.5f], ' \
                  '[loss4_l2h %.5f], [lr %.13f]' % \
                  (curr_iter, train_loss_record.avg, loss_fuse_record.avg, loss1_h2l_record.avg,
                   loss2_h2l_record.avg, loss3_h2l_record.avg, loss4_h2l_record.avg,
                   loss1_l2h_record.avg, loss2_l2h_record.avg, loss3_l2h_record.avg,
                   loss4_l2h_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter > args['iter_num']:
                torch.save(net.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                return
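# AvgMeter is used throughout these scripts but defined in a shared utility
# module. A minimal sketch of the running-average behavior the training loops
# above rely on (assumed reconstruction; the repo's version may differ):
class AvgMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # Accumulate a value observed n times and refresh the running mean.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count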
def main():
    net = R3Net_prior(motion='GRU', se_layer=False, st_fuse=False)

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                   map_location='cuda:0'))
    # net = train_online(net)

    results = {}
    for name, root in to_test.items():
        precision_record = [AvgMeter() for _ in range(256)]
        recall_record = [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        if args['save_results']:
            check_mkdir(
                os.path.join(ckpt_path, exp_name,
                             '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

        folders = os.listdir(root)
        folders.sort()
        for folder in folders:
            # Online adaptation: fine-tune on each sequence before testing it.
            net = train_online(net, seq_name=folder)
            with torch.no_grad():
                net.eval()
                net.cuda()
                imgs = os.listdir(os.path.join(root, folder))
                imgs.sort()

                for i in range(1, len(imgs) - args['batch_size'] + 1):
                    print(imgs[i])
                    img_var = []
                    img_names = []
                    for j in range(0, args['batch_size']):
                        img = Image.open(
                            os.path.join(root, folder, imgs[i + j])).convert('RGB')
                        img_names.append(imgs[i + j])
                        shape = img.size
                        img = img.resize(args['input_size'])
                        # volatile=True is obsolete; torch.no_grad() already covers this.
                        img_var.append(img_transform(img).unsqueeze(0).cuda())

                    img_var = torch.cat(img_var, dim=0)
                    prediction = net(img_var)

                    precision = to_pil(prediction.data.squeeze(0).cpu())
                    precision = precision.resize(shape)
                    prediction = np.array(precision)

                    if args['crf_refine']:
                        prediction = crf_refine(np.array(img), prediction)

                    gt = np.array(
                        Image.open(os.path.join(gt_root, folder,
                                                img_names[-1][:-4] + '.png')).convert('L'))
                    precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                    for pidx, pdata in enumerate(zip(precision, recall)):
                        p, r = pdata
                        precision_record[pidx].update(p)
                        recall_record[pidx].update(r)
                    mae_record.update(mae)

                    if args['save_results']:
                        # folder, sub_name = os.path.split(img_names[-1])
                        save_path = os.path.join(
                            ckpt_path, exp_name,
                            '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                        if not os.path.exists(save_path):
                            os.makedirs(save_path)
                        Image.fromarray(prediction).save(
                            os.path.join(save_path, img_names[-1][:-4] + '.png'))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])
        results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
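# cal_fmeasure is imported from a shared module. In the saliency literature the
# F-measure uses beta^2 = 0.3 and is maximized over the 256 thresholds; a
# sketch under that assumption (the repo's actual helper may differ in detail):
def cal_fmeasure(precision, recall, beta2=0.3):
    # precision and recall are per-threshold averages of length 256.
    return max((1 + beta2) * p * r / (beta2 * p + r + 1e-12)
               for p, r in zip(precision, recall))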
# Evaluation fragment: the enclosing loop over folders/imgs is elided in the
# source. 'pred = Image.open(' is restored from the dangling call; 'image' is
# loaded in the elided part.
pred = Image.open(
    os.path.join(root_inference, folder, img[:-4] + '.png')).convert('L')
gt = Image.open(os.path.join(gt_root, folder, img[:-4] + '.png')).convert('L')
gt = gt.resize(pred.size)
image = image.resize(pred.size)

gt = np.array(gt)
pred = np.array(pred)

precision, recall, mae = cal_precision_recall_mae(pred, gt)
for pidx, pdata in enumerate(zip(precision, recall)):
    p, r = pdata
    precision_record[pidx].update(p)
    recall_record[pidx].update(r)
mae_record.update(mae)

fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                        [rrecord.avg for rrecord in recall_record])
results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

print('test results:')
print(results)

# Recorded results on DAVIS for reference:
# {'davis': {'mae': 0.041576569176772944, 'fmeasure': 0.8341383096984007}}
# {'MSST_davis': {'fmeasure': 0.8175943834081874, 'mae': 0.04597473876855389}}
# {'Amulet_davis': {'mae': 0.08374974551689243, 'fmeasure': 0.7234079968968813}}
# {'CG_davis': {'fmeasure': 0.6278087775523111, 'mae': 0.09568971798828023}}
# {'CS_davis': {'fmeasure': 0.387371123540425, 'mae': 0.11592338609834756}}
# {'DCL_davis': {'fmeasure': 0.7555328232313439, 'mae': 0.1325773794024856}}
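# cal_precision_recall_mae is also external. A sketch of the standard
# per-threshold evaluation it presumably performs (an assumption, not verified
# against the repo's helper): prediction and gt are uint8 arrays in [0, 255].
def cal_precision_recall_mae(prediction, gt):
    eps = 1e-12
    prediction = prediction / 255.0
    gt = (gt / 255.0) > 0.5
    mae = np.mean(np.abs(prediction - gt))
    precision, recall = [], []
    # One precision/recall pair per threshold, matching the 256 AvgMeters above.
    for threshold in range(256):
        binary = prediction >= (threshold / 255.0)
        tp = np.sum(binary & gt)
        precision.append(tp / (np.sum(binary) + eps))
        recall.append(tp / (np.sum(gt) + eps))
    return precision, recall, mae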
def validate(g, curr_epoch, d=None):
    g.eval()
    mse_criterion = nn.MSELoss()
    g_mse_loss_record, psnr_record = AvgMeter(), AvgMeter()

    for name, loader in val_loader.items():
        val_visual = []
        # note that the batch size is 1
        with torch.no_grad():  # volatile=True is obsolete; disable autograd here instead
            for i, data in enumerate(loader):
                hr_img, _ = data
                lr_img, hr_interpolated_img = val_lr_transform(hr_img.squeeze(0))
                lr_img = lr_img.unsqueeze(0).cuda()
                hr_img = hr_img.cuda()

                gen_hr_img = g(lr_img)
                g_mse_loss = mse_criterion(gen_hr_img, hr_img)
                g_mse_loss_record.update(g_mse_loss.item())
                # PSNR assumes pixel values in [0, 1], so the peak signal is 1.
                psnr_record.update(10 * np.log10(1 / g_mse_loss.item()))

                val_visual.extend([
                    val_display_transform(hr_interpolated_img),
                    val_display_transform(hr_img.cpu().data.squeeze(0)),
                    val_display_transform(gen_hr_img.cpu().data.squeeze(0))
                ])

        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)

        snapshot_name = 'epoch_%d_%s_g_mse_loss_%.5f_psnr_%.5f' % (
            curr_epoch + 1, name, g_mse_loss_record.avg, psnr_record.avg)
        if d is None:
            snapshot_name = 'pretrain_' + snapshot_name
            writer.add_scalar('pretrain_validate_%s_psnr' % name,
                              psnr_record.avg, curr_epoch + 1)
            writer.add_scalar('pretrain_validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)
            print('[pretrain validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]' %
                  (name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))
        else:
            writer.add_scalar('validate_%s_psnr' % name, psnr_record.avg, curr_epoch + 1)
            writer.add_scalar('validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)
            print('[validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]' %
                  (name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))
            torch.save(d.state_dict(),
                       os.path.join(train_args['ckpt_path'], snapshot_name + '_d.pth'))
            torch.save(g.state_dict(),
                       os.path.join(train_args['ckpt_path'], snapshot_name + '_g.pth'))

        writer.add_image(snapshot_name, val_visual)
        g_mse_loss_record.reset()
        psnr_record.reset()

    g.train()
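# The inline PSNR computation above, written out as a small helper for clarity
# (psnr_from_mse is an illustrative name; 'peak' generalizes beyond the [0, 1]
# range this script assumes):
def psnr_from_mse(mse, peak=1.0):
    # PSNR in dB: 10 * log10(peak^2 / MSE); with peak=1 this reduces to the
    # expression used in validate().
    return 10 * np.log10(peak ** 2 / mse)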