def perform_operation(file_path):
    """Evaluate the full pipeline (e_net -> a_net -> s_net -> fusion) on the
    image list at `file_path` and print per-AU and mean F1/accuracy.

    Relies on module-level names defined elsewhere in this file: e_net, a_net,
    s_net, fusion, args, prep, ImageList, prep_model_input, get_f1, get_acc,
    tqdm.
    """
    # Put every sub-network in inference mode (freezes dropout/BN behavior).
    e_net.eval()
    a_net.eval()
    s_net.eval()
    fusion.eval()
    imDataset = ImageList(crop_size=args.IM_SIZE, path=file_path,
                          img_path=args.img_path, NUM_CLASS=args.NUM_CLASS,
                          phase='test',
                          transform=prep.image_test(crop_size=args.IM_SIZE),
                          target_transform=prep.land_transform(img_size=args.IM_SIZE))
    imDataLoader = torch.utils.data.DataLoader(imDataset,
                                               batch_size=args.Test_BATCH,
                                               num_workers=0)
    pbar = tqdm(total=len(imDataLoader))
    # BUG FIX: the original called `torch.no_grad()` as a bare statement, which
    # builds and immediately discards the context manager — autograd stayed ON
    # and evaluation tracked gradients for every forward pass (wasted memory).
    # The context manager must wrap the code it guards.
    with torch.no_grad():
        for batch_Idx, data in enumerate(imDataLoader):
            datablob, datalb, pos_para = data
            # torch.autograd.Variable is a deprecated no-op wrapper in modern
            # PyTorch; kept to match the rest of this file.
            datablob = torch.autograd.Variable(datablob).cuda()
            y_lb = torch.autograd.Variable(datalb).view(datalb.size(0), -1).cuda()
            pos_para = torch.autograd.Variable(pos_para).cuda()
            pred_global = e_net(datablob)                # global AU prediction
            feat_data = e_net.predict_BN(datablob)       # backbone feature map
            pred_att_map, pred_conf = a_net(feat_data)   # attention maps + confidence
            slice_feat_data = prep_model_input(pred_att_map, pos_para)
            pred_local = s_net(slice_feat_data)          # local (sliced) AU prediction
            cls_pred = fusion(pred_global + pred_local)
            cls_pred = cls_pred.data.cpu().float()
            y_lb = y_lb.data.cpu().float()
            # Accumulate predictions and labels across all batches so the
            # metrics below are computed over the whole set at once.
            if batch_Idx == 0:
                all_output = cls_pred
                all_label = y_lb
            else:
                all_output = torch.cat((all_output, cls_pred), 0)
                all_label = torch.cat((all_label, y_lb), 0)
            pbar.update()
    pbar.close()
    all_acc_scr = get_acc(all_output, all_label)
    all_f1_score = get_f1(all_output, all_label)
    print('f1 score: ', str(all_f1_score.numpy().tolist()))
    print('average f1 score: ', str(all_f1_score.mean().numpy().tolist()))
    print('acc score: ', str(all_acc_scr.numpy().tolist()))
    print('average acc score: ', str(all_acc_scr.mean().numpy().tolist()))
def main(config):
    """Train the multi-branch AU detection network with joint AU + landmark loss.

    Builds train/test datasets, instantiates six sub-networks from
    network.network_dict, optionally resumes from a snapshot, then alternates
    per-epoch snapshotting + test-set evaluation with SGD-style training.

    NOTE(review): relies on module-level names defined elsewhere in this file
    (network, prep, util_data, optim_dict, lr_schedule, AU_detection_evalv2,
    au_softmax_loss, au_dice_loss, landmark_loss).
    """
    ## set loss criterion
    use_gpu = torch.cuda.is_available()
    # Per-AU class weights computed offline; stored next to the train list file.
    au_weight = torch.from_numpy(np.loadtxt(config.train_path_prefix + '_weight.txt'))
    if use_gpu:
        au_weight = au_weight.float().cuda()
    else:
        au_weight = au_weight.float()

    ## prepare data
    dsets = {}
    dset_loaders = {}
    dsets['train'] = ImageList(crop_size=config.crop_size, path=config.train_path_prefix,
                               transform=prep.image_train(crop_size=config.crop_size),
                               target_transform=prep.land_transform(img_size=config.crop_size,
                                                                    flip_reflect=np.loadtxt(
                                                                        config.flip_reflect)))
    dset_loaders['train'] = util_data.DataLoader(dsets['train'], batch_size=config.train_batch_size,
                                                 shuffle=True, num_workers=config.num_workers)
    dsets['test'] = ImageList(crop_size=config.crop_size, path=config.test_path_prefix, phase='test',
                              transform=prep.image_test(crop_size=config.crop_size),
                              target_transform=prep.land_transform(img_size=config.crop_size,
                                                                   flip_reflect=np.loadtxt(
                                                                       config.flip_reflect)))
    dset_loaders['test'] = util_data.DataLoader(dsets['test'], batch_size=config.eval_batch_size,
                                                shuffle=False, num_workers=config.num_workers)

    ## set network modules
    region_learning = network.network_dict[config.region_learning](input_dim=3, unit_dim=config.unit_dim)
    align_net = network.network_dict[config.align_net](crop_size=config.crop_size, map_size=config.map_size,
                                                       au_num=config.au_num, land_num=config.land_num,
                                                       input_dim=config.unit_dim * 8,
                                                       fill_coeff=config.fill_coeff)
    local_attention_refine = network.network_dict[config.local_attention_refine](au_num=config.au_num,
                                                                                 unit_dim=config.unit_dim)
    local_au_net = network.network_dict[config.local_au_net](au_num=config.au_num,
                                                             input_dim=config.unit_dim * 8,
                                                             unit_dim=config.unit_dim)
    global_au_feat = network.network_dict[config.global_au_feat](input_dim=config.unit_dim * 8,
                                                                 unit_dim=config.unit_dim)
    au_net = network.network_dict[config.au_net](au_num=config.au_num, input_dim=12000,
                                                 unit_dim=config.unit_dim)

    if config.start_epoch > 0:
        # Resume: every sub-network has its own snapshot file per epoch.
        print('resuming model from epoch %d' % (config.start_epoch))
        region_learning.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/region_learning_' + str(config.start_epoch) + '.pth'))
        align_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/align_net_' + str(config.start_epoch) + '.pth'))
        local_attention_refine.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(config.start_epoch) + '.pth'))
        local_au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_au_net_' + str(config.start_epoch) + '.pth'))
        global_au_feat.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/global_au_feat_' + str(config.start_epoch) + '.pth'))
        au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/au_net_' + str(config.start_epoch) + '.pth'))

    if use_gpu:
        region_learning = region_learning.cuda()
        align_net = align_net.cuda()
        local_attention_refine = local_attention_refine.cuda()
        local_au_net = local_au_net.cuda()
        global_au_feat = global_au_feat.cuda()
        au_net = au_net.cuda()

    print(region_learning)
    print(align_net)
    print(local_attention_refine)
    print(local_au_net)
    print(global_au_feat)
    print(au_net)

    ## collect parameters
    # 'lr': 1 is a base scale per group; the lr_scheduler below rescales it
    # every iteration against config.init_lr.
    region_learning_parameter_list = [{'params': filter(lambda p: p.requires_grad, region_learning.parameters()), 'lr': 1}]
    align_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, align_net.parameters()), 'lr': 1}]
    local_attention_refine_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_attention_refine.parameters()), 'lr': 1}]
    local_au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_au_net.parameters()), 'lr': 1}]
    global_au_feat_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, global_au_feat.parameters()), 'lr': 1}]
    au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, au_net.parameters()), 'lr': 1}]

    ## set optimizer
    optimizer = optim_dict[config.optimizer_type](itertools.chain(region_learning_parameter_list,
                                                                  align_net_parameter_list,
                                                                  local_attention_refine_parameter_list,
                                                                  local_au_net_parameter_list,
                                                                  global_au_feat_parameter_list,
                                                                  au_net_parameter_list),
                                                  lr=1.0, momentum=config.momentum,
                                                  weight_decay=config.weight_decay,
                                                  nesterov=config.use_nesterov)
    # Remember the base per-group rates so the scheduler can rescale them.
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])
    lr_scheduler = lr_schedule.schedule_dict[config.lr_type]

    if not os.path.exists(config.write_path_prefix + config.run_name):
        os.makedirs(config.write_path_prefix + config.run_name)
    if not os.path.exists(config.write_res_prefix + config.run_name):
        os.makedirs(config.write_res_prefix + config.run_name)

    res_file = open(
        config.write_res_prefix + config.run_name + '/AU_pred_' + str(config.start_epoch) + '.txt', 'w')

    ## train
    count = 0  # total training iterations across all epochs
    for epoch in range(config.start_epoch, config.n_epochs + 1):
        if epoch > config.start_epoch:
            # Snapshot the weights reached at the end of the previous epoch.
            print('taking snapshot ...')
            torch.save(region_learning.state_dict(),
                       config.write_path_prefix + config.run_name + '/region_learning_' + str(epoch) + '.pth')
            torch.save(align_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/align_net_' + str(epoch) + '.pth')
            torch.save(local_attention_refine.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(epoch) + '.pth')
            torch.save(local_au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_au_net_' + str(epoch) + '.pth')
            torch.save(global_au_feat.state_dict(),
                       config.write_path_prefix + config.run_name + '/global_au_feat_' + str(epoch) + '.pth')
            torch.save(au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/au_net_' + str(epoch) + '.pth')

        # eval in the train
        if epoch > config.start_epoch:
            print('testing ...')
            region_learning.train(False)
            align_net.train(False)
            local_attention_refine.train(False)
            local_au_net.train(False)
            global_au_feat.train(False)
            au_net.train(False)
            local_f1score_arr, local_acc_arr, f1score_arr, acc_arr, mean_error, failure_rate = AU_detection_evalv2(
                dset_loaders['test'], region_learning, align_net, local_attention_refine,
                local_au_net, global_au_feat, au_net, use_gpu=use_gpu)
            print('epoch =%d, local f1 score mean=%f, local accuracy mean=%f, '
                  'f1 score mean=%f, accuracy mean=%f, mean error=%f, failure rate=%f'
                  % (epoch, local_f1score_arr.mean(), local_acc_arr.mean(),
                     f1score_arr.mean(), acc_arr.mean(), mean_error, failure_rate))
            print('%d\t%f\t%f\t%f\t%f\t%f\t%f' % (epoch, local_f1score_arr.mean(),
                                                  local_acc_arr.mean(), f1score_arr.mean(),
                                                  acc_arr.mean(), mean_error, failure_rate),
                  file=res_file)
            region_learning.train(True)
            align_net.train(True)
            local_attention_refine.train(True)
            local_au_net.train(True)
            global_au_feat.train(True)
            au_net.train(True)

        if epoch >= config.n_epochs:
            break

        for i, batch in enumerate(dset_loaders['train']):
            if i % config.display == 0 and count > 0:
                print('[epoch = %d][iter = %d][total_loss = %f][loss_au_softmax = %f][loss_au_dice = %f]'
                      '[loss_local_au_softmax = %f][loss_local_au_dice = %f]'
                      '[loss_land = %f]' % (epoch, i, total_loss.data.cpu().numpy(),
                                            loss_au_softmax.data.cpu().numpy(),
                                            loss_au_dice.data.cpu().numpy(),
                                            loss_local_au_softmax.data.cpu().numpy(),
                                            loss_local_au_dice.data.cpu().numpy(),
                                            loss_land.data.cpu().numpy()))
                print('learning rate = %f %f %f %f %f %f' % (optimizer.param_groups[0]['lr'],
                                                             optimizer.param_groups[1]['lr'],
                                                             optimizer.param_groups[2]['lr'],
                                                             optimizer.param_groups[3]['lr'],
                                                             optimizer.param_groups[4]['lr'],
                                                             optimizer.param_groups[5]['lr']))
                print('the number of training iterations is %d' % (count))

            # NOTE(review): `input` shadows the builtin; left unchanged here.
            input, land, biocular, au = batch
            if use_gpu:
                input, land, biocular, au = input.cuda(), land.float().cuda(), \
                                            biocular.float().cuda(), au.long().cuda()
            else:
                au = au.long()

            # Re-derive this step's learning rate, then reset gradients.
            optimizer = lr_scheduler(param_lr, optimizer, epoch, config.gamma,
                                     config.stepsize, config.init_lr)
            optimizer.zero_grad()

            region_feat = region_learning(input)
            align_feat, align_output, aus_map = align_net(region_feat)
            if use_gpu:
                aus_map = aus_map.cuda()
            # Attention maps are detached so the refinement branch does not
            # back-propagate into align_net through them.
            output_aus_map = local_attention_refine(aus_map.detach())
            local_au_out_feat, local_aus_output = local_au_net(region_feat, output_aus_map)
            global_au_out_feat = global_au_feat(region_feat)
            concat_au_feat = torch.cat((align_feat, global_au_out_feat,
                                        local_au_out_feat.detach()), 1)
            aus_output = au_net(concat_au_feat)

            loss_au_softmax = au_softmax_loss(aus_output, au, weight=au_weight)
            loss_au_dice = au_dice_loss(aus_output, au, weight=au_weight)
            loss_au = loss_au_softmax + loss_au_dice
            loss_local_au_softmax = au_softmax_loss(local_aus_output, au, weight=au_weight)
            loss_local_au_dice = au_dice_loss(local_aus_output, au, weight=au_weight)
            loss_local_au = loss_local_au_softmax + loss_local_au_dice
            loss_land = landmark_loss(align_output, land, biocular)

            total_loss = config.lambda_au * (loss_au + loss_local_au) + \
                         config.lambda_land * loss_land

            total_loss.backward()
            optimizer.step()
            count = count + 1

    res_file.close()
def main(config):
    """Adversarially adapt an AU detector from a source to a target domain.

    Jointly trains a shared base network + landmark/AU encoders (task), a
    shape-invariant encoder + feature generator (gen), and three
    discriminators (landmark-shape plus per-domain feature GANs) with
    cross-domain swap / self-reconstruction / cycle losses. Snapshots and
    per-epoch validation scores are written under config.write_path_prefix /
    config.write_res_prefix.

    Fixes vs. the original (Python-2 leftovers that fail under Python 3):
      * `print >> res_file, ...` -> `print(..., file=res_file)`
        (the old form evaluates `print >> res_file` and raises TypeError).
      * `iter_data_tgt.next()` -> `next(iter_data_tgt)`
        (iterators have no `.next()` method in Python 3).
    """
    ## set loss criterion
    use_gpu = torch.cuda.is_available()
    au_weight_src = torch.from_numpy(
        np.loadtxt(config.src_train_path_prefix + '_weight.txt'))
    if use_gpu:
        au_weight_src = au_weight_src.float().cuda()
    else:
        au_weight_src = au_weight_src.float()
    au_class_criterion = nn.BCEWithLogitsLoss(au_weight_src)
    land_predict_criterion = land_softmax_loss
    discriminator_criterion = nn.MSELoss()  # least-squares GAN objective
    reconstruct_criterion = nn.L1Loss()
    land_discriminator_criterion = land_discriminator_loss
    land_adaptation_criterion = land_adaptation_loss

    ## prepare data
    dsets = {}
    dset_loaders = {}
    dsets['source'] = {}
    dset_loaders['source'] = {}
    dsets['source']['train'] = ImageList_land_au(
        config.crop_size, config.src_train_path_prefix,
        transform=prep.image_train(crop_size=config.crop_size),
        target_transform=prep.land_transform(
            output_size=config.output_size,
            scale=config.crop_size / config.output_size,
            flip_reflect=np.loadtxt(config.flip_reflect)))
    dset_loaders['source']['train'] = util_data.DataLoader(
        dsets['source']['train'], batch_size=config.train_batch_size,
        shuffle=True, num_workers=config.num_workers)
    dsets['source']['val'] = ImageList_au(
        config.src_val_path_prefix,
        transform=prep.image_test(crop_size=config.crop_size))
    dset_loaders['source']['val'] = util_data.DataLoader(
        dsets['source']['val'], batch_size=config.eval_batch_size,
        shuffle=False, num_workers=config.num_workers)
    dsets['target'] = {}
    dset_loaders['target'] = {}
    dsets['target']['train'] = ImageList_land_au(
        config.crop_size, config.tgt_train_path_prefix,
        transform=prep.image_train(crop_size=config.crop_size),
        target_transform=prep.land_transform(
            output_size=config.output_size,
            scale=config.crop_size / config.output_size,
            flip_reflect=np.loadtxt(config.flip_reflect)))
    dset_loaders['target']['train'] = util_data.DataLoader(
        dsets['target']['train'], batch_size=config.train_batch_size,
        shuffle=True, num_workers=config.num_workers)
    dsets['target']['val'] = ImageList_au(
        config.tgt_val_path_prefix,
        transform=prep.image_test(crop_size=config.crop_size))
    dset_loaders['target']['val'] = util_data.DataLoader(
        dsets['target']['val'], batch_size=config.eval_batch_size,
        shuffle=False, num_workers=config.num_workers)

    ## set network modules
    base_net = network.network_dict[config.base_net]()
    land_enc = network.network_dict[config.land_enc](land_num=config.land_num)
    # Separate copy of the landmark encoder, refreshed from land_enc every
    # iteration, used to score generated features.
    land_enc_store = network.network_dict[config.land_enc](
        land_num=config.land_num)
    au_enc = network.network_dict[config.au_enc](au_num=config.au_num)
    invar_shape_enc = network.network_dict[config.invar_shape_enc]()
    feat_gen = network.network_dict[config.feat_gen]()
    invar_shape_disc = network.network_dict[config.invar_shape_disc](
        land_num=config.land_num)
    feat_gen_disc_src = network.network_dict[config.feat_gen_disc]()
    feat_gen_disc_tgt = network.network_dict[config.feat_gen_disc]()

    if config.start_epoch > 0:
        # Resume every module from its snapshot at config.start_epoch.
        base_net.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/base_net_' +
                       str(config.start_epoch) + '.pth'))
        land_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/land_enc_' +
                       str(config.start_epoch) + '.pth'))
        au_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/au_enc_' +
                       str(config.start_epoch) + '.pth'))
        invar_shape_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/invar_shape_enc_' + str(config.start_epoch) + '.pth'))
        feat_gen.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/feat_gen_' +
                       str(config.start_epoch) + '.pth'))
        invar_shape_disc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/invar_shape_disc_' + str(config.start_epoch) + '.pth'))
        feat_gen_disc_src.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/feat_gen_disc_src_' + str(config.start_epoch) + '.pth'))
        feat_gen_disc_tgt.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/feat_gen_disc_tgt_' + str(config.start_epoch) + '.pth'))

    if use_gpu:
        base_net = base_net.cuda()
        land_enc = land_enc.cuda()
        land_enc_store = land_enc_store.cuda()
        au_enc = au_enc.cuda()
        invar_shape_enc = invar_shape_enc.cuda()
        feat_gen = feat_gen.cuda()
        invar_shape_disc = invar_shape_disc.cuda()
        feat_gen_disc_src = feat_gen_disc_src.cuda()
        feat_gen_disc_tgt = feat_gen_disc_tgt.cuda()

    ## collect parameters
    # 'lr': 1 is a base scale; the schedulers rescale each group per step.
    base_net_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, base_net.parameters()),
        'lr': 1
    }]
    land_enc_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, land_enc.parameters()),
        'lr': 1
    }]
    au_enc_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, au_enc.parameters()),
        'lr': 1
    }]
    invar_shape_enc_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, invar_shape_enc.parameters()),
        'lr': 1
    }]
    feat_gen_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, feat_gen.parameters()),
        'lr': 1
    }]
    invar_shape_disc_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, invar_shape_disc.parameters()),
        'lr': 1
    }]
    feat_gen_disc_src_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, feat_gen_disc_src.parameters()),
        'lr': 1
    }]
    feat_gen_disc_tgt_parameter_list = [{
        'params': filter(lambda p: p.requires_grad, feat_gen_disc_tgt.parameters()),
        'lr': 1
    }]

    ## set optimizer
    Gen_optimizer = optim_dict[config.gen_optimizer_type](itertools.chain(
        invar_shape_enc_parameter_list, feat_gen_parameter_list), 1.0,
        [config.gen_beta1, config.gen_beta2])
    Task_optimizer = optim_dict[config.task_optimizer_type](itertools.chain(
        base_net_parameter_list, land_enc_parameter_list,
        au_enc_parameter_list), 1.0, [config.task_beta1, config.task_beta2])
    Disc_optimizer = optim_dict[config.gen_optimizer_type](
        itertools.chain(invar_shape_disc_parameter_list,
                        feat_gen_disc_src_parameter_list,
                        feat_gen_disc_tgt_parameter_list), 1.0,
        [config.gen_beta1, config.gen_beta2])

    Gen_param_lr = []
    for param_group in Gen_optimizer.param_groups:
        Gen_param_lr.append(param_group['lr'])
    Task_param_lr = []
    for param_group in Task_optimizer.param_groups:
        Task_param_lr.append(param_group['lr'])
    Disc_param_lr = []
    for param_group in Disc_optimizer.param_groups:
        Disc_param_lr.append(param_group['lr'])

    Gen_lr_scheduler = lr_schedule.schedule_dict[config.gen_lr_type]
    Task_lr_scheduler = lr_schedule.schedule_dict[config.task_lr_type]
    Disc_lr_scheduler = lr_schedule.schedule_dict[config.gen_lr_type]

    print(base_net, land_enc, au_enc, invar_shape_enc, feat_gen)
    print(invar_shape_disc, feat_gen_disc_src, feat_gen_disc_tgt)

    if not os.path.exists(config.write_path_prefix + config.mode):
        os.makedirs(config.write_path_prefix + config.mode)
    if not os.path.exists(config.write_res_prefix + config.mode):
        os.makedirs(config.write_res_prefix + config.mode)

    val_type = 'target'  # 'source'
    res_file = open(
        config.write_res_prefix + config.mode + '/' + val_type + '_AU_pred_' +
        str(config.start_epoch) + '.txt', 'w')

    ## train
    len_train_tgt = len(dset_loaders['target']['train'])
    count = 0  # total training iterations across all epochs
    for epoch in range(config.start_epoch, config.n_epochs + 1):
        # eval in the train (note: >= start_epoch, so epoch 0 is also scored)
        if epoch >= config.start_epoch:
            base_net.train(False)
            land_enc.train(False)
            au_enc.train(False)
            invar_shape_enc.train(False)
            feat_gen.train(False)
            if val_type == 'source':
                f1score_arr, acc_arr = AU_detection_eval_src(
                    dset_loaders[val_type]['val'], base_net, au_enc,
                    use_gpu=use_gpu)
            else:
                f1score_arr, acc_arr = AU_detection_eval_tgt(
                    dset_loaders[val_type]['val'], base_net, land_enc, au_enc,
                    invar_shape_enc, feat_gen, use_gpu=use_gpu)
            print('epoch =%d, f1 score mean=%f, accuracy mean=%f' %
                  (epoch, f1score_arr.mean(), acc_arr.mean()))
            # FIX: was `print >> res_file, ...` (Python-2 syntax; TypeError at
            # runtime under Python 3). Matches the print(file=...) style used
            # elsewhere in this file.
            print('%d\t%f\t%f' % (epoch, f1score_arr.mean(), acc_arr.mean()),
                  file=res_file)
            base_net.train(True)
            land_enc.train(True)
            au_enc.train(True)
            invar_shape_enc.train(True)
            feat_gen.train(True)

        if epoch > config.start_epoch:
            print('taking snapshot ...')
            torch.save(base_net.state_dict(),
                       config.write_path_prefix + config.mode + '/base_net_' + str(epoch) + '.pth')
            torch.save(land_enc.state_dict(),
                       config.write_path_prefix + config.mode + '/land_enc_' + str(epoch) + '.pth')
            torch.save(au_enc.state_dict(),
                       config.write_path_prefix + config.mode + '/au_enc_' + str(epoch) + '.pth')
            torch.save(invar_shape_enc.state_dict(),
                       config.write_path_prefix + config.mode + '/invar_shape_enc_' + str(epoch) + '.pth')
            torch.save(feat_gen.state_dict(),
                       config.write_path_prefix + config.mode + '/feat_gen_' + str(epoch) + '.pth')
            torch.save(invar_shape_disc.state_dict(),
                       config.write_path_prefix + config.mode + '/invar_shape_disc_' + str(epoch) + '.pth')
            torch.save(feat_gen_disc_src.state_dict(),
                       config.write_path_prefix + config.mode + '/feat_gen_disc_src_' + str(epoch) + '.pth')
            torch.save(feat_gen_disc_tgt.state_dict(),
                       config.write_path_prefix + config.mode + '/feat_gen_disc_tgt_' + str(epoch) + '.pth')

        if epoch >= config.n_epochs:
            break

        for i, batch_src in enumerate(dset_loaders['source']['train']):
            if i % config.display == 0 and count > 0:
                print(
                    '[epoch = %d][iter = %d][loss_disc = %f][loss_invar_shape_disc = %f][loss_gen_disc = %f][total_loss = %f][loss_invar_shape_adaptation = %f][loss_gen_adaptation = %f][loss_self_recons = %f][loss_gen_cycle = %f][loss_au = %f][loss_land = %f]'
                    % (epoch, i, loss_disc.data.cpu().numpy(),
                       loss_invar_shape_disc.data.cpu().numpy(),
                       loss_gen_disc.data.cpu().numpy(),
                       total_loss.data.cpu().numpy(),
                       loss_invar_shape_adaptation.data.cpu().numpy(),
                       loss_gen_adaptation.data.cpu().numpy(),
                       loss_self_recons.data.cpu().numpy(),
                       loss_gen_cycle.data.cpu().numpy(),
                       loss_au.data.cpu().numpy(),
                       loss_land.data.cpu().numpy()))
                print('learning rate = %f, %f, %f' %
                      (Disc_optimizer.param_groups[0]['lr'],
                       Gen_optimizer.param_groups[0]['lr'],
                       Task_optimizer.param_groups[0]['lr']))
                print('the number of training iterations is %d' % (count))

            input_src, land_src, au_src = batch_src

            # Restart the target iterator whenever it has been exhausted
            # (every len_train_tgt iterations), reshuffling after the first pass.
            if count % len_train_tgt == 0:
                if count > 0:
                    dset_loaders['target']['train'] = util_data.DataLoader(
                        dsets['target']['train'],
                        batch_size=config.train_batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)
                iter_data_tgt = iter(dset_loaders['target']['train'])
            # FIX: was `iter_data_tgt.next()` — Python-2 iterator protocol;
            # Python 3 requires the builtin next().
            input_tgt, land_tgt, au_tgt = next(iter_data_tgt)

            # Trim the larger batch so source and target batches match in size
            # (either loader's final batch may be short).
            if input_tgt.size(0) > input_src.size(0):
                input_tgt, land_tgt, au_tgt = input_tgt[
                    0:input_src.size(0), :, :, :], land_tgt[
                        0:input_src.size(0), :], au_tgt[0:input_src.size(0)]
            elif input_tgt.size(0) < input_src.size(0):
                input_src, land_src, au_src = input_src[
                    0:input_tgt.size(0), :, :, :], land_src[
                        0:input_tgt.size(0), :], au_src[0:input_tgt.size(0)]

            if use_gpu:
                input_src, land_src, au_src, input_tgt, land_tgt, au_tgt = \
                    input_src.cuda(), land_src.long().cuda(), au_src.float().cuda(), \
                    input_tgt.cuda(), land_tgt.long().cuda(), au_tgt.float().cuda()
            else:
                land_src, au_src, land_tgt, au_tgt = \
                    land_src.long(), au_src.float(), land_tgt.long(), au_tgt.float()

            # Refresh the stored landmark-encoder copy with the live weights.
            land_enc_store.load_state_dict(land_enc.state_dict())

            base_feat_src = base_net(input_src)
            align_attention_src, align_feat_src, align_output_src = land_enc(
                base_feat_src)
            au_feat_src, au_output_src = au_enc(base_feat_src)
            base_feat_tgt = base_net(input_tgt)
            align_attention_tgt, align_feat_tgt, align_output_tgt = land_enc(
                base_feat_tgt)
            au_feat_tgt, au_output_tgt = au_enc(base_feat_tgt)

            invar_shape_output_src = invar_shape_enc(base_feat_src.detach())
            invar_shape_output_tgt = invar_shape_enc(base_feat_tgt.detach())

            # new_gen: cross-domain swap — source attention with target
            # shape-invariant code, and vice versa.
            new_gen_tgt = feat_gen(align_attention_src.detach(),
                                   invar_shape_output_tgt)
            new_gen_src = feat_gen(align_attention_tgt.detach(),
                                   invar_shape_output_src)

            # recons_gen: same-domain self-reconstruction.
            recons_gen_src = feat_gen(align_attention_src.detach(),
                                      invar_shape_output_src)
            recons_gen_tgt = feat_gen(align_attention_tgt.detach(),
                                      invar_shape_output_tgt)

            # new2_gen: swap back to close the cycle.
            new_gen_invar_shape_output_src = invar_shape_enc(
                new_gen_src.detach())
            new_gen_invar_shape_output_tgt = invar_shape_enc(
                new_gen_tgt.detach())
            new_gen_align_attention_src, new_gen_align_feat_src, new_gen_align_output_src = land_enc_store(
                new_gen_src)
            new_gen_align_attention_tgt, new_gen_align_feat_tgt, new_gen_align_output_tgt = land_enc_store(
                new_gen_tgt)
            new2_gen_tgt = feat_gen(new_gen_align_attention_src.detach(),
                                    new_gen_invar_shape_output_tgt)
            new2_gen_src = feat_gen(new_gen_align_attention_tgt.detach(),
                                    new_gen_invar_shape_output_src)

            ############################
            # 1. train discriminator  #
            ############################
            Disc_optimizer = Disc_lr_scheduler(Disc_param_lr, Disc_optimizer,
                                               epoch, config.n_epochs, 1,
                                               config.decay_start_epoch,
                                               config.gen_lr)
            Disc_optimizer.zero_grad()
            # Inputs are detached: only the discriminators update in phase 1.
            align_output_invar_shape_src = invar_shape_disc(
                invar_shape_output_src.detach())
            align_output_invar_shape_tgt = invar_shape_disc(
                invar_shape_output_tgt.detach())
            # loss_invar_shape_disc
            loss_base_invar_shape_disc_src = land_discriminator_criterion(
                align_output_invar_shape_src, land_src)
            loss_base_invar_shape_disc_tgt = land_discriminator_criterion(
                align_output_invar_shape_tgt, land_tgt)
            loss_invar_shape_disc = (loss_base_invar_shape_disc_src +
                                     loss_base_invar_shape_disc_tgt) * 0.5
            base_gen_src_pred = feat_gen_disc_src(base_feat_src.detach())
            new_gen_src_pred = feat_gen_disc_src(new_gen_src.detach())
            real_label = torch.ones((base_feat_src.size(0), 1))
            fake_label = torch.zeros((base_feat_src.size(0), 1))
            if use_gpu:
                real_label, fake_label = real_label.cuda(), fake_label.cuda()
            # loss_gen_disc_src: real = encoder features, fake = generated.
            loss_base_gen_src = discriminator_criterion(base_gen_src_pred,
                                                        real_label)
            loss_new_gen_src = discriminator_criterion(new_gen_src_pred,
                                                       fake_label)
            loss_gen_disc_src = (loss_base_gen_src + loss_new_gen_src) * 0.5
            base_gen_tgt_pred = feat_gen_disc_tgt(base_feat_tgt.detach())
            new_gen_tgt_pred = feat_gen_disc_tgt(new_gen_tgt.detach())
            # loss_gen_disc_tgt
            loss_base_gen_tgt = discriminator_criterion(base_gen_tgt_pred,
                                                        real_label)
            loss_new_gen_tgt = discriminator_criterion(new_gen_tgt_pred,
                                                       fake_label)
            loss_gen_disc_tgt = (loss_base_gen_tgt + loss_new_gen_tgt) * 0.5
            # loss_gen_disc
            loss_gen_disc = (loss_gen_disc_src + loss_gen_disc_tgt) * 0.5
            loss_disc = loss_invar_shape_disc + loss_gen_disc
            loss_disc.backward()
            # optimize discriminator
            Disc_optimizer.step()

            ############################
            # 2. train base network   #
            ############################
            Gen_optimizer = Gen_lr_scheduler(Gen_param_lr, Gen_optimizer,
                                             epoch, config.n_epochs, 1,
                                             config.decay_start_epoch,
                                             config.gen_lr)
            Gen_optimizer.zero_grad()
            Task_optimizer = Task_lr_scheduler(Task_param_lr, Task_optimizer,
                                               epoch, config.n_epochs, 1,
                                               config.decay_start_epoch,
                                               config.task_lr)
            Task_optimizer.zero_grad()
            # No detach here: adaptation losses flow into the encoders/generator.
            align_output_invar_shape_src = invar_shape_disc(
                invar_shape_output_src)
            align_output_invar_shape_tgt = invar_shape_disc(
                invar_shape_output_tgt)
            # loss_invar_shape_adaptation
            loss_base_invar_shape_adaptation_src = land_adaptation_criterion(
                align_output_invar_shape_src)
            loss_base_invar_shape_adaptation_tgt = land_adaptation_criterion(
                align_output_invar_shape_tgt)
            loss_invar_shape_adaptation = (
                loss_base_invar_shape_adaptation_src +
                loss_base_invar_shape_adaptation_tgt) * 0.5
            new_gen_src_pred = feat_gen_disc_src(new_gen_src)
            loss_gen_adaptation_src = discriminator_criterion(
                new_gen_src_pred, real_label)
            new_gen_tgt_pred = feat_gen_disc_tgt(new_gen_tgt)
            loss_gen_adaptation_tgt = discriminator_criterion(
                new_gen_tgt_pred, real_label)
            # loss_gen_adaptation
            loss_gen_adaptation = (loss_gen_adaptation_src +
                                   loss_gen_adaptation_tgt) * 0.5
            loss_gen_cycle_src = reconstruct_criterion(new2_gen_src,
                                                       base_feat_src.detach())
            loss_gen_cycle_tgt = reconstruct_criterion(new2_gen_tgt,
                                                       base_feat_tgt.detach())
            # loss_gen_cycle
            loss_gen_cycle = (loss_gen_cycle_src + loss_gen_cycle_tgt) * 0.5
            loss_self_recons_src = reconstruct_criterion(
                recons_gen_src, base_feat_src.detach())
            loss_self_recons_tgt = reconstruct_criterion(
                recons_gen_tgt, base_feat_tgt.detach())
            # loss_self_recons
            loss_self_recons = (loss_self_recons_src +
                                loss_self_recons_tgt) * 0.5
            loss_base_gen_au_src = au_class_criterion(au_output_src, au_src)
            loss_base_gen_au_tgt = au_class_criterion(au_output_tgt, au_tgt)
            loss_base_gen_land_src = land_predict_criterion(
                align_output_src, land_src)
            loss_base_gen_land_tgt = land_predict_criterion(
                align_output_tgt, land_tgt)
            new_gen_au_feat_src, new_gen_au_output_src = au_enc(new_gen_src)
            new_gen_au_feat_tgt, new_gen_au_output_tgt = au_enc(new_gen_tgt)
            # NOTE(review): labels are swapped across domains for the generated
            # features — consistent with the cross-domain swap in new_gen above;
            # verify against the original paper's supervision scheme.
            loss_new_gen_au_src = au_class_criterion(new_gen_au_output_src,
                                                     au_tgt)
            loss_new_gen_au_tgt = au_class_criterion(new_gen_au_output_tgt,
                                                     au_src)
            loss_new_gen_land_src = land_predict_criterion(
                new_gen_align_output_src, land_tgt)
            loss_new_gen_land_tgt = land_predict_criterion(
                new_gen_align_output_tgt, land_src)
            # loss_land
            loss_land = (loss_base_gen_land_src + loss_base_gen_land_tgt +
                         loss_new_gen_land_src + loss_new_gen_land_tgt) * 0.5
            # loss_au: 'weak' mode drops the terms supervised directly by
            # target AU labels.
            if config.mode == 'weak':
                loss_au = (loss_base_gen_au_src + loss_new_gen_au_tgt) * 0.5
            else:
                loss_au = (loss_base_gen_au_src + loss_base_gen_au_tgt +
                           loss_new_gen_au_src + loss_new_gen_au_tgt) * 0.25
            total_loss = config.lambda_land_adv * loss_invar_shape_adaptation + \
                         config.lambda_feat_adv * loss_gen_adaptation + \
                         config.lambda_cross_cycle * loss_gen_cycle + \
                         config.lambda_self_recons * loss_self_recons + \
                         config.lambda_au * loss_au + \
                         config.lambda_land * loss_land
            total_loss.backward()
            Gen_optimizer.step()
            Task_optimizer.step()
            count = count + 1

    res_file.close()
def main(config):
    """Offline evaluation / attention visualization over saved snapshots.

    For each epoch in [start_epoch, n_epochs], loads that epoch's saved
    sub-network weights, then (a) runs AU detection evaluation when
    config.pred_AU is set and (b) dumps attention-map visualizations when
    config.vis_attention is set.

    NOTE(review): relies on module-level names defined elsewhere in this file
    (network, prep, util_data, AU_detection_evalv2, vis_attention).
    """
    ## set loss criterion
    use_gpu = torch.cuda.is_available()

    ## prepare data
    dsets = {}
    dset_loaders = {}
    dsets['test'] = ImageList(
        crop_size=config.crop_size,
        path=config.test_path_prefix,
        phase='test',
        transform=prep.image_test(crop_size=config.crop_size),
        target_transform=prep.land_transform(img_size=config.crop_size,
                                             flip_reflect=np.loadtxt(
                                                 config.flip_reflect)))
    dset_loaders['test'] = util_data.DataLoader(
        dsets['test'],
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=config.num_workers)

    ## set network modules
    region_learning = network.network_dict[config.region_learning](
        input_dim=3, unit_dim=config.unit_dim)
    align_net = network.network_dict[config.align_net](
        crop_size=config.crop_size,
        map_size=config.map_size,
        au_num=config.au_num,
        land_num=config.land_num,
        input_dim=config.unit_dim * 8)
    local_attention_refine = network.network_dict[
        config.local_attention_refine](au_num=config.au_num,
                                       unit_dim=config.unit_dim)
    local_au_net = network.network_dict[config.local_au_net](
        au_num=config.au_num,
        input_dim=config.unit_dim * 8,
        unit_dim=config.unit_dim)
    global_au_feat = network.network_dict[config.global_au_feat](
        input_dim=config.unit_dim * 8, unit_dim=config.unit_dim)
    au_net = network.network_dict[config.au_net](au_num=config.au_num,
                                                 input_dim=12000,
                                                 unit_dim=config.unit_dim)
    if use_gpu:
        region_learning = region_learning.cuda()
        align_net = align_net.cuda()
        local_attention_refine = local_attention_refine.cuda()
        local_au_net = local_au_net.cuda()
        global_au_feat = global_au_feat.cuda()
        au_net = au_net.cuda()

    if not os.path.exists(config.write_path_prefix + config.run_name):
        os.makedirs(config.write_path_prefix + config.run_name)
    if not os.path.exists(config.write_res_prefix + config.run_name):
        os.makedirs(config.write_res_prefix + config.run_name)

    # Offline mode needs at least one existing snapshot to load.
    if config.start_epoch <= 0:
        raise (RuntimeError('start_epoch should be larger than 0\n'))

    res_file = open(
        config.write_res_prefix + config.run_name + '/' + config.prefix +
        'offline_AU_pred_' + str(config.start_epoch) + '.txt', 'w')

    # Every network stays in inference mode for the whole run.
    region_learning.train(False)
    align_net.train(False)
    local_attention_refine.train(False)
    local_au_net.train(False)
    global_au_feat.train(False)
    au_net.train(False)

    for epoch in range(config.start_epoch, config.n_epochs + 1):
        # Load this epoch's snapshot for every sub-network.
        region_learning.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/region_learning_' + str(epoch) + '.pth'))
        align_net.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/align_net_' + str(epoch) + '.pth'))
        local_attention_refine.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/local_attention_refine_' + str(epoch) + '.pth'))
        local_au_net.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/local_au_net_' + str(epoch) + '.pth'))
        global_au_feat.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/global_au_feat_' + str(epoch) + '.pth'))
        au_net.load_state_dict(
            torch.load(config.write_path_prefix + config.run_name +
                       '/au_net_' + str(epoch) + '.pth'))

        if config.pred_AU:
            local_f1score_arr, local_acc_arr, f1score_arr, acc_arr, mean_error, failure_rate = AU_detection_evalv2(
                dset_loaders['test'], region_learning, align_net,
                local_attention_refine, local_au_net, global_au_feat, au_net,
                use_gpu=use_gpu)
            print(
                'epoch =%d, local f1 score mean=%f, local accuracy mean=%f, '
                'f1 score mean=%f, accuracy mean=%f, mean error=%f, failure rate=%f'
                % (epoch, local_f1score_arr.mean(), local_acc_arr.mean(),
                   f1score_arr.mean(), acc_arr.mean(), mean_error,
                   failure_rate))
            print(
                '%d\t%f\t%f\t%f\t%f\t%f\t%f' %
                (epoch, local_f1score_arr.mean(), local_acc_arr.mean(),
                 f1score_arr.mean(), acc_arr.mean(), mean_error, failure_rate),
                file=res_file)

        if config.vis_attention:
            # Ensure per-epoch output directories exist before writing maps.
            if not os.path.exists(config.write_res_prefix + config.run_name +
                                  '/vis_map/' + str(epoch)):
                os.makedirs(config.write_res_prefix + config.run_name +
                            '/vis_map/' + str(epoch))
            if not os.path.exists(config.write_res_prefix + config.run_name +
                                  '/overlay_vis_map/' + str(epoch)):
                os.makedirs(config.write_res_prefix + config.run_name +
                            '/overlay_vis_map/' + str(epoch))
            vis_attention(dset_loaders['test'], region_learning, align_net,
                          local_attention_refine, config.write_res_prefix,
                          config.run_name, epoch, use_gpu=use_gpu)

    res_file.close()
def perform_operation(file_path, operation, epoch):
    """Run one epoch of training ('Train') or a full test pass ('Test') of the
    WS-DAFNet pipeline over the image list at `file_path`.

    Relies on module-level names defined elsewhere in this file: e_net, a_net,
    s_net, fusion, optimizer, args, prep, ImageList, prep_model_input, get_f1,
    get_acc, fout_test, fout_test_f1, fout_test_f1_mean, fout_test_acc,
    fout_test_acc_mean.
    """
    if operation == 'Train':
        e_net.train()
        a_net.train()
        s_net.train()
        fusion.train()
    else:
        e_net.eval()
        a_net.eval()
        s_net.eval()
        fusion.eval()

    if operation == 'Train':
        # NOTE(review): the Train branch uses phase='test' and
        # prep.image_test transforms, same as the Test branch — possibly
        # deliberate (no augmentation), but worth confirming.
        imDataset = ImageList(
            crop_size=args.IM_SIZE,
            path=file_path,
            img_path=args.img_path,
            NUM_CLASS=args.NUM_CLASS,
            phase='test',
            transform=prep.image_test(crop_size=args.IM_SIZE),
            target_transform=prep.land_transform(img_size=args.IM_SIZE))
        imDataLoader = torch.utils.data.DataLoader(imDataset,
                                                   batch_size=args.Train_BATCH,
                                                   shuffle=True,
                                                   num_workers=0)
    else:
        imDataset = ImageList(
            crop_size=args.IM_SIZE,
            path=file_path,
            img_path=args.img_path,
            NUM_CLASS=args.NUM_CLASS,
            phase='test',
            transform=prep.image_test(crop_size=args.IM_SIZE),
            target_transform=prep.land_transform(img_size=args.IM_SIZE))
        imDataLoader = torch.utils.data.DataLoader(imDataset,
                                                   batch_size=args.Test_BATCH,
                                                   num_workers=0)

    # BUG FIX: the original called `torch.enable_grad()` / `torch.no_grad()`
    # as bare statements, which construct and discard the context managers and
    # never actually toggle autograd — test passes still tracked gradients.
    # set_grad_enabled used as a context manager applies the mode correctly
    # and restores the previous mode on exit.
    with torch.set_grad_enabled(operation == 'Train'):
        for batch_Idx, data in enumerate(imDataLoader):
            if operation == 'Train':
                print('%s Epoch: %d Batch_Idx: %d' % (operation, epoch, batch_Idx))
                optimizer.zero_grad()
            datablob, datalb, pos_para = data
            # torch.autograd.Variable is a deprecated no-op wrapper in modern
            # PyTorch; kept to match the rest of this file.
            datablob = torch.autograd.Variable(datablob).cuda()
            y_lb = torch.autograd.Variable(datalb).view(datalb.size(0), -1).cuda()
            pos_para = torch.autograd.Variable(pos_para).cuda()
            bceLoss_cls = nn.BCEWithLogitsLoss()
            bceLoss2_att = nn.BCEWithLogitsLoss()
            pred_global = e_net(datablob)                # global AU prediction
            feat_data = e_net.predict_BN(datablob)       # backbone feature map
            pred_att_map, pred_conf = a_net(feat_data)   # attention maps + confidence
            slice_feat_data = prep_model_input(pred_att_map, pos_para)
            pred_local = s_net(slice_feat_data)          # local (sliced) AU prediction
            cls_pred = fusion(pred_global + pred_local)
            cls_loss = bceLoss_cls(cls_pred, y_lb)
            att_loss = bceLoss2_att(pred_conf, y_lb)
            sum_loss = cls_loss + att_loss
            if operation == 'Train':
                sum_loss.backward()
            cls_pred = cls_pred.data.cpu().float()
            y_lb = y_lb.data.cpu().float()
            f1_score = get_f1(cls_pred, y_lb)
            acc_scr = get_acc(cls_pred, y_lb)
            if operation == 'Test':
                # Accumulate predictions/labels for whole-set metrics below.
                if batch_Idx == 0:
                    all_output = cls_pred
                    all_label = y_lb
                else:
                    all_output = torch.cat((all_output, cls_pred), 0)
                    all_label = torch.cat((all_label, y_lb), 0)
            if operation == 'Train':
                print('acc_scr', acc_scr.mean().cpu().data.item(),
                      'f1_score', f1_score.mean().cpu().data.item(),
                      'sum_loss', sum_loss.cpu().data.item())
                optimizer.step()
            if operation == 'Test':
                fout_test.write('Label:' + str(y_lb) + '->' + 'Pre:' + str(cls_pred) + '\n')
            # Free per-batch tensors eagerly to cap GPU/host memory use.
            del datablob, y_lb, pos_para, feat_data, pred_att_map, pred_conf, \
                slice_feat_data, pred_local, cls_pred, cls_loss, att_loss, \
                sum_loss, acc_scr, f1_score

    if operation == 'Test':
        all_acc_scr = get_acc(all_output, all_label)
        all_f1_score = get_f1(all_output, all_label)
        fout_test_f1.write('***' + str(all_f1_score.numpy().tolist()) + '\n')
        fout_test_f1_mean.write('***' + str(all_f1_score.mean().numpy().tolist()) + '\n')
        fout_test_acc.write('***' + str(all_acc_scr.numpy().tolist()) + '\n')
        fout_test_acc_mean.write('***' + str(all_acc_scr.mean().numpy().tolist()) + '\n')
        print('average f1 score: ', str(all_f1_score.mean().numpy().tolist()))
        print('average acc score: ', str(all_acc_scr.mean().numpy().tolist()))
        del all_acc_scr, all_f1_score, all_output, all_label

    if operation == 'Train':
        # Snapshot the whole modules (not just state_dicts) for this epoch.
        new_model = './result/snap/' + args.version + '/WS-DAFNet_' + args.name + '_' + str(
            epoch) + '.pth'
        torch.save([e_net, a_net, s_net, fusion], new_model)
        print('save ' + new_model)