# Example 1
def perform_operation(file_path):
    """Evaluate the trained attention/fusion AU pipeline on a test image list.

    Runs the module-level networks (``e_net``, ``a_net``, ``s_net``,
    ``fusion``) in eval mode over every image listed in *file_path* and
    prints per-class and mean F1 / accuracy scores.

    Args:
        file_path: path to the test list file consumed by ``ImageList``.
    """
    e_net.eval()
    a_net.eval()
    s_net.eval()
    fusion.eval()

    imDataset = ImageList(crop_size=args.IM_SIZE, path=file_path, img_path=args.img_path, NUM_CLASS=args.NUM_CLASS,
              phase='test', transform=prep.image_test(crop_size=args.IM_SIZE),
              target_transform=prep.land_transform(img_size=args.IM_SIZE))
    imDataLoader = torch.utils.data.DataLoader(imDataset, batch_size=args.Test_BATCH, num_workers=0)

    outputs = []
    labels = []
    pbar = tqdm(total=len(imDataLoader))
    # BUG FIX: the original code called `torch.no_grad()` as a bare statement,
    # which constructs and immediately discards the context manager, so
    # gradient tracking stayed enabled throughout inference. It must wrap the
    # forward passes in a `with` block to take effect.
    with torch.no_grad():
        for batch_Idx, data in enumerate(imDataLoader):
            datablob, datalb, pos_para = data
            # torch.autograd.Variable is a deprecated no-op wrapper since
            # PyTorch 0.4 — move tensors to the GPU directly.
            datablob = datablob.cuda()
            y_lb = datalb.view(datalb.size(0), -1).cuda()
            pos_para = pos_para.cuda()

            # Global prediction branch plus attention-guided local branch,
            # fused by summing logits before the fusion head.
            pred_global = e_net(datablob)
            feat_data = e_net.predict_BN(datablob)
            pred_att_map, pred_conf = a_net(feat_data)
            slice_feat_data = prep_model_input(pred_att_map, pos_para)
            pred_local = s_net(slice_feat_data)
            cls_pred = fusion(pred_global + pred_local)

            outputs.append(cls_pred.data.cpu().float())
            labels.append(y_lb.data.cpu().float())
            pbar.update()

    pbar.close()
    # Concatenate once at the end instead of torch.cat per batch, which
    # recopied the growing accumulator on every iteration (O(n^2)).
    all_output = torch.cat(outputs, 0)
    all_label = torch.cat(labels, 0)
    all_acc_scr = get_acc(all_output, all_label)
    all_f1_score = get_f1(all_output, all_label)

    print('f1 score: ', str(all_f1_score.numpy().tolist()))
    print('average f1 score: ', str(all_f1_score.mean().numpy().tolist()))
    print('acc score: ', str(all_acc_scr.numpy().tolist()))
    print('average acc score: ', str(all_acc_scr.mean().numpy().tolist()))
# Example 2
def main(config):
    """Train the multi-branch facial AU (Action Unit) detection model.

    Per epoch: snapshot the current weights, evaluate on the test split
    (appending scores to a per-run results file), then run one pass over
    the training split optimizing a weighted sum of global/local AU
    softmax+dice losses and a landmark alignment loss.

    Args:
        config: namespace-like object carrying all hyper-parameters and
            paths (crop_size, batch sizes, network names, lr schedule,
            write/result path prefixes, loss weights, ...).
    """
    ## set loss criterion
    use_gpu = torch.cuda.is_available()
    # Per-AU class weights, stored alongside the training list file.
    au_weight = torch.from_numpy(np.loadtxt(config.train_path_prefix + '_weight.txt'))
    if use_gpu:
        au_weight = au_weight.float().cuda()
    else:
        au_weight = au_weight.float()

    ## prepare data
    dsets = {}
    dset_loaders = {}

    # Training split: augmenting transform; landmark targets are mirrored
    # consistently with image flips via the flip_reflect index table.
    dsets['train'] = ImageList(crop_size=config.crop_size, path=config.train_path_prefix,
                                       transform=prep.image_train(crop_size=config.crop_size),
                                       target_transform=prep.land_transform(img_size=config.crop_size,
                                                                            flip_reflect=np.loadtxt(
                                                                                config.flip_reflect)))

    dset_loaders['train'] = util_data.DataLoader(dsets['train'], batch_size=config.train_batch_size,
                                                 shuffle=True, num_workers=config.num_workers)

    # Test split: deterministic preprocessing (phase='test'), no shuffling.
    dsets['test'] = ImageList(crop_size=config.crop_size, path=config.test_path_prefix, phase='test',
                                       transform=prep.image_test(crop_size=config.crop_size),
                                       target_transform=prep.land_transform(img_size=config.crop_size,
                                                                            flip_reflect=np.loadtxt(
                                                                                config.flip_reflect))
                                       )

    dset_loaders['test'] = util_data.DataLoader(dsets['test'], batch_size=config.eval_batch_size,
                                                shuffle=False, num_workers=config.num_workers)

    ## set network modules
    # Each module is looked up by name in network.network_dict, so the whole
    # architecture is selectable from the config.
    region_learning = network.network_dict[config.region_learning](input_dim=3, unit_dim = config.unit_dim)
    align_net = network.network_dict[config.align_net](crop_size=config.crop_size, map_size=config.map_size,
                                                           au_num=config.au_num, land_num=config.land_num,
                                                           input_dim=config.unit_dim*8, fill_coeff=config.fill_coeff)
    local_attention_refine = network.network_dict[config.local_attention_refine](au_num=config.au_num, unit_dim=config.unit_dim)
    local_au_net = network.network_dict[config.local_au_net](au_num=config.au_num, input_dim=config.unit_dim*8,
                                                                                     unit_dim=config.unit_dim)
    global_au_feat = network.network_dict[config.global_au_feat](input_dim=config.unit_dim*8,
                                                                                     unit_dim=config.unit_dim)
    au_net = network.network_dict[config.au_net](au_num=config.au_num, input_dim = 12000, unit_dim = config.unit_dim)


    # Resume all six modules from the checkpoints of `start_epoch` if
    # training is being continued.
    if config.start_epoch > 0:
        print('resuming model from epoch %d' %(config.start_epoch))
        region_learning.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/region_learning_' + str(config.start_epoch) + '.pth'))
        align_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/align_net_' + str(config.start_epoch) + '.pth'))
        local_attention_refine.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(config.start_epoch) + '.pth'))
        local_au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/local_au_net_' + str(config.start_epoch) + '.pth'))
        global_au_feat.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/global_au_feat_' + str(config.start_epoch) + '.pth'))
        au_net.load_state_dict(torch.load(
            config.write_path_prefix + config.run_name + '/au_net_' + str(config.start_epoch) + '.pth'))

    if use_gpu:
        region_learning = region_learning.cuda()
        align_net = align_net.cuda()
        local_attention_refine = local_attention_refine.cuda()
        local_au_net = local_au_net.cuda()
        global_au_feat = global_au_feat.cuda()
        au_net = au_net.cuda()

    # Log the instantiated architectures for the run record.
    print(region_learning)
    print(align_net)
    print(local_attention_refine)
    print(local_au_net)
    print(global_au_feat)
    print(au_net)

    ## collect parameters
    # 'lr': 1 is a placeholder multiplier; the actual learning rate is set
    # every iteration by lr_scheduler below (see the training loop).
    region_learning_parameter_list = [{'params': filter(lambda p: p.requires_grad, region_learning.parameters()), 'lr': 1}]
    align_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, align_net.parameters()), 'lr': 1}]
    local_attention_refine_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_attention_refine.parameters()), 'lr': 1}]
    local_au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, local_au_net.parameters()), 'lr': 1}]
    global_au_feat_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, global_au_feat.parameters()), 'lr': 1}]
    au_net_parameter_list = [
        {'params': filter(lambda p: p.requires_grad, au_net.parameters()), 'lr': 1}]

    ## set optimizer
    # itertools.chain flattens the six single-element parameter-group lists
    # into one iterable of param groups for a single optimizer.
    optimizer = optim_dict[config.optimizer_type](itertools.chain(region_learning_parameter_list, align_net_parameter_list,
                                                                  local_attention_refine_parameter_list,
                                                                  local_au_net_parameter_list,
                                                                  global_au_feat_parameter_list,
                                                                  au_net_parameter_list),
                                                  lr=1.0, momentum=config.momentum, weight_decay=config.weight_decay,
                                                  nesterov=config.use_nesterov)
    # Record the base lr of each param group; the scheduler scales these.
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])

    # lr_scheduler is a schedule *function* selected by name.
    lr_scheduler = lr_schedule.schedule_dict[config.lr_type]

    if not os.path.exists(config.write_path_prefix + config.run_name):
        os.makedirs(config.write_path_prefix + config.run_name)
    if not os.path.exists(config.write_res_prefix + config.run_name):
        os.makedirs(config.write_res_prefix + config.run_name)

    # Per-run results file ('w' truncates any previous file for this
    # start_epoch); closed at the end of training.
    res_file = open(
        config.write_res_prefix + config.run_name + '/AU_pred_' + str(config.start_epoch) + '.txt', 'w')

    ## train
    count = 0  # total number of training iterations across all epochs

    # Loop runs to n_epochs inclusive so the final pass only snapshots and
    # evaluates (the `break` below skips its training phase).
    for epoch in range(config.start_epoch, config.n_epochs + 1):
        if epoch > config.start_epoch:
            print('taking snapshot ...')
            torch.save(region_learning.state_dict(),
                       config.write_path_prefix + config.run_name + '/region_learning_' + str(epoch) + '.pth')
            torch.save(align_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/align_net_' + str(epoch) + '.pth')
            torch.save(local_attention_refine.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_attention_refine_' + str(epoch) + '.pth')
            torch.save(local_au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/local_au_net_' + str(epoch) + '.pth')
            torch.save(global_au_feat.state_dict(),
                       config.write_path_prefix + config.run_name + '/global_au_feat_' + str(epoch) + '.pth')
            torch.save(au_net.state_dict(),
                       config.write_path_prefix + config.run_name + '/au_net_' + str(epoch) + '.pth')

        # eval in the train
        if epoch > config.start_epoch:
            print('testing ...')
            # Switch all modules to eval mode (affects BatchNorm/Dropout).
            # NOTE(review): evaluation runs without torch.no_grad(), so
            # autograd state is still tracked during testing — confirm this
            # is intended (it costs memory but does not change the scores).
            region_learning.train(False)
            align_net.train(False)
            local_attention_refine.train(False)
            local_au_net.train(False)
            global_au_feat.train(False)
            au_net.train(False)

            local_f1score_arr, local_acc_arr, f1score_arr, acc_arr, mean_error, failure_rate = AU_detection_evalv2(
                dset_loaders['test'], region_learning, align_net, local_attention_refine,
                local_au_net, global_au_feat, au_net, use_gpu=use_gpu)
            print('epoch =%d, local f1 score mean=%f, local accuracy mean=%f, '
                  'f1 score mean=%f, accuracy mean=%f, mean error=%f, failure rate=%f' % (epoch, local_f1score_arr.mean(),
                                local_acc_arr.mean(), f1score_arr.mean(),
                                acc_arr.mean(), mean_error, failure_rate))
            # Tab-separated copy of the same metrics goes to the results file.
            print('%d\t%f\t%f\t%f\t%f\t%f\t%f' % (epoch, local_f1score_arr.mean(),
                                                local_acc_arr.mean(), f1score_arr.mean(),
                                                acc_arr.mean(), mean_error, failure_rate), file=res_file)

            # Back to training mode for the next epoch.
            region_learning.train(True)
            align_net.train(True)
            local_attention_refine.train(True)
            local_au_net.train(True)
            global_au_feat.train(True)
            au_net.train(True)

        # After the extra snapshot/eval pass at epoch == n_epochs, stop
        # before training again.
        if epoch >= config.n_epochs:
            break

        for i, batch in enumerate(dset_loaders['train']):
            # Periodic logging; the loss variables below hold values from
            # the PREVIOUS iteration, so `count > 0` guards the first pass.
            if i % config.display == 0 and count > 0:
                print('[epoch = %d][iter = %d][total_loss = %f][loss_au_softmax = %f][loss_au_dice = %f]'
                      '[loss_local_au_softmax = %f][loss_local_au_dice = %f]'
                      '[loss_land = %f]' % (epoch, i,
                    total_loss.data.cpu().numpy(), loss_au_softmax.data.cpu().numpy(), loss_au_dice.data.cpu().numpy(),
                    loss_local_au_softmax.data.cpu().numpy(), loss_local_au_dice.data.cpu().numpy(), loss_land.data.cpu().numpy()))
                print('learning rate = %f %f %f %f %f %f' % (optimizer.param_groups[0]['lr'],
                                                          optimizer.param_groups[1]['lr'],
                                                          optimizer.param_groups[2]['lr'],
                                                          optimizer.param_groups[3]['lr'],
                                                          optimizer.param_groups[4]['lr'],
                                                          optimizer.param_groups[5]['lr']))
                print('the number of training iterations is %d' % (count))

            # batch = (image tensor, landmark targets, biocular normalizer,
            # AU labels) — shapes defined by ImageList, not visible here.
            input, land, biocular, au = batch

            if use_gpu:
                input, land, biocular, au = input.cuda(), land.float().cuda(), \
                                            biocular.float().cuda(), au.long().cuda()
            else:
                au = au.long()

            # Re-derive the per-group learning rates for this iteration from
            # the decay schedule, then clear gradients.
            optimizer = lr_scheduler(param_lr, optimizer, epoch, config.gamma, config.stepsize, config.init_lr)
            optimizer.zero_grad()

            # Forward: shared region features feed the alignment branch, the
            # attention-refined local AU branch, and the global AU branch.
            region_feat = region_learning(input)
            align_feat, align_output, aus_map = align_net(region_feat)
            if use_gpu:
                aus_map = aus_map.cuda()
            # detach(): the refinement branch must not backprop into align_net
            # through the attention maps.
            output_aus_map = local_attention_refine(aus_map.detach())
            local_au_out_feat, local_aus_output = local_au_net(region_feat, output_aus_map)
            global_au_out_feat = global_au_feat(region_feat)
            # Local features are detached so the fusion AU loss does not
            # update the local branch through this path.
            concat_au_feat = torch.cat((align_feat, global_au_out_feat, local_au_out_feat.detach()), 1)
            aus_output = au_net(concat_au_feat)

            # Fused-branch AU loss: class-weighted softmax + dice.
            loss_au_softmax = au_softmax_loss(aus_output, au, weight=au_weight)
            loss_au_dice = au_dice_loss(aus_output, au, weight=au_weight)
            loss_au = loss_au_softmax + loss_au_dice

            # Local-branch AU loss, same two terms.
            loss_local_au_softmax = au_softmax_loss(local_aus_output, au, weight=au_weight)
            loss_local_au_dice = au_dice_loss(local_aus_output, au, weight=au_weight)
            loss_local_au = loss_local_au_softmax + loss_local_au_dice

            # Landmark alignment loss, normalized by the biocular distance.
            loss_land = landmark_loss(align_output, land, biocular)

            # Weighted total: AU terms vs. landmark term.
            total_loss = config.lambda_au * (loss_au + loss_local_au) + \
                         config.lambda_land * loss_land

            total_loss.backward()
            optimizer.step()

            count = count + 1

    res_file.close()
# Example 3
def main(config):
    ## set loss criterion
    use_gpu = torch.cuda.is_available()
    au_weight_src = torch.from_numpy(
        np.loadtxt(config.src_train_path_prefix + '_weight.txt'))
    if use_gpu:
        au_weight_src = au_weight_src.float().cuda()
    else:
        au_weight_src = au_weight_src.float()

    au_class_criterion = nn.BCEWithLogitsLoss(au_weight_src)
    land_predict_criterion = land_softmax_loss
    discriminator_criterion = nn.MSELoss()
    reconstruct_criterion = nn.L1Loss()
    land_discriminator_criterion = land_discriminator_loss
    land_adaptation_criterion = land_adaptation_loss

    ## prepare data
    dsets = {}
    dset_loaders = {}
    dsets['source'] = {}
    dset_loaders['source'] = {}

    dsets['source']['train'] = ImageList_land_au(
        config.crop_size,
        config.src_train_path_prefix,
        transform=prep.image_train(crop_size=config.crop_size),
        target_transform=prep.land_transform(
            output_size=config.output_size,
            scale=config.crop_size / config.output_size,
            flip_reflect=np.loadtxt(config.flip_reflect)))

    dset_loaders['source']['train'] = util_data.DataLoader(
        dsets['source']['train'],
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers)

    dsets['source']['val'] = ImageList_au(
        config.src_val_path_prefix,
        transform=prep.image_test(crop_size=config.crop_size))

    dset_loaders['source']['val'] = util_data.DataLoader(
        dsets['source']['val'],
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=config.num_workers)

    dsets['target'] = {}
    dset_loaders['target'] = {}

    dsets['target']['train'] = ImageList_land_au(
        config.crop_size,
        config.tgt_train_path_prefix,
        transform=prep.image_train(crop_size=config.crop_size),
        target_transform=prep.land_transform(
            output_size=config.output_size,
            scale=config.crop_size / config.output_size,
            flip_reflect=np.loadtxt(config.flip_reflect)))

    dset_loaders['target']['train'] = util_data.DataLoader(
        dsets['target']['train'],
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers)

    dsets['target']['val'] = ImageList_au(
        config.tgt_val_path_prefix,
        transform=prep.image_test(crop_size=config.crop_size))

    dset_loaders['target']['val'] = util_data.DataLoader(
        dsets['target']['val'],
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=config.num_workers)

    ## set network modules
    base_net = network.network_dict[config.base_net]()
    land_enc = network.network_dict[config.land_enc](land_num=config.land_num)
    land_enc_store = network.network_dict[config.land_enc](
        land_num=config.land_num)
    au_enc = network.network_dict[config.au_enc](au_num=config.au_num)
    invar_shape_enc = network.network_dict[config.invar_shape_enc]()
    feat_gen = network.network_dict[config.feat_gen]()
    invar_shape_disc = network.network_dict[config.invar_shape_disc](
        land_num=config.land_num)
    feat_gen_disc_src = network.network_dict[config.feat_gen_disc]()
    feat_gen_disc_tgt = network.network_dict[config.feat_gen_disc]()

    if config.start_epoch > 0:
        base_net.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/base_net_' +
                       str(config.start_epoch) + '.pth'))
        land_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/land_enc_' +
                       str(config.start_epoch) + '.pth'))
        au_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/au_enc_' +
                       str(config.start_epoch) + '.pth'))
        invar_shape_enc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/invar_shape_enc_' + str(config.start_epoch) + '.pth'))
        feat_gen.load_state_dict(
            torch.load(config.write_path_prefix + config.mode + '/feat_gen_' +
                       str(config.start_epoch) + '.pth'))
        invar_shape_disc.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/invar_shape_disc_' + str(config.start_epoch) +
                       '.pth'))
        feat_gen_disc_src.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/feat_gen_disc_src_' + str(config.start_epoch) +
                       '.pth'))
        feat_gen_disc_tgt.load_state_dict(
            torch.load(config.write_path_prefix + config.mode +
                       '/feat_gen_disc_tgt_' + str(config.start_epoch) +
                       '.pth'))

    if use_gpu:
        base_net = base_net.cuda()
        land_enc = land_enc.cuda()
        land_enc_store = land_enc_store.cuda()
        au_enc = au_enc.cuda()
        invar_shape_enc = invar_shape_enc.cuda()
        feat_gen = feat_gen.cuda()
        invar_shape_disc = invar_shape_disc.cuda()
        feat_gen_disc_src = feat_gen_disc_src.cuda()
        feat_gen_disc_tgt = feat_gen_disc_tgt.cuda()

    ## collect parameters
    base_net_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, base_net.parameters()),
        'lr':
        1
    }]
    land_enc_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, land_enc.parameters()),
        'lr':
        1
    }]
    au_enc_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, au_enc.parameters()),
        'lr':
        1
    }]
    invar_shape_enc_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, invar_shape_enc.parameters()),
        'lr':
        1
    }]
    feat_gen_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, feat_gen.parameters()),
        'lr':
        1
    }]
    invar_shape_disc_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, invar_shape_disc.parameters()),
        'lr':
        1
    }]
    feat_gen_disc_src_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, feat_gen_disc_src.parameters()),
        'lr':
        1
    }]
    feat_gen_disc_tgt_parameter_list = [{
        'params':
        filter(lambda p: p.requires_grad, feat_gen_disc_tgt.parameters()),
        'lr':
        1
    }]

    ## set optimizer
    Gen_optimizer = optim_dict[config.gen_optimizer_type](itertools.chain(
        invar_shape_enc_parameter_list,
        feat_gen_parameter_list), 1.0, [config.gen_beta1, config.gen_beta2])
    Task_optimizer = optim_dict[config.task_optimizer_type](itertools.chain(
        base_net_parameter_list, land_enc_parameter_list,
        au_enc_parameter_list), 1.0, [config.task_beta1, config.task_beta2])
    Disc_optimizer = optim_dict[config.gen_optimizer_type](
        itertools.chain(invar_shape_disc_parameter_list,
                        feat_gen_disc_src_parameter_list,
                        feat_gen_disc_tgt_parameter_list), 1.0,
        [config.gen_beta1, config.gen_beta2])

    Gen_param_lr = []
    for param_group in Gen_optimizer.param_groups:
        Gen_param_lr.append(param_group['lr'])

    Task_param_lr = []
    for param_group in Task_optimizer.param_groups:
        Task_param_lr.append(param_group['lr'])

    Disc_param_lr = []
    for param_group in Disc_optimizer.param_groups:
        Disc_param_lr.append(param_group['lr'])

    Gen_lr_scheduler = lr_schedule.schedule_dict[config.gen_lr_type]
    Task_lr_scheduler = lr_schedule.schedule_dict[config.task_lr_type]
    Disc_lr_scheduler = lr_schedule.schedule_dict[config.gen_lr_type]

    print(base_net, land_enc, au_enc, invar_shape_enc, feat_gen)
    print(invar_shape_disc, feat_gen_disc_src, feat_gen_disc_tgt)

    if not os.path.exists(config.write_path_prefix + config.mode):
        os.makedirs(config.write_path_prefix + config.mode)
    if not os.path.exists(config.write_res_prefix + config.mode):
        os.makedirs(config.write_res_prefix + config.mode)

    val_type = 'target'  # 'source'
    res_file = open(
        config.write_res_prefix + config.mode + '/' + val_type + '_AU_pred_' +
        str(config.start_epoch) + '.txt', 'w')

    ## train
    len_train_tgt = len(dset_loaders['target']['train'])
    count = 0

    for epoch in range(config.start_epoch, config.n_epochs + 1):
        # eval in the train
        if epoch >= config.start_epoch:
            base_net.train(False)
            land_enc.train(False)
            au_enc.train(False)
            invar_shape_enc.train(False)
            feat_gen.train(False)
            if val_type == 'source':
                f1score_arr, acc_arr = AU_detection_eval_src(
                    dset_loaders[val_type]['val'],
                    base_net,
                    au_enc,
                    use_gpu=use_gpu)
            else:
                f1score_arr, acc_arr = AU_detection_eval_tgt(
                    dset_loaders[val_type]['val'],
                    base_net,
                    land_enc,
                    au_enc,
                    invar_shape_enc,
                    feat_gen,
                    use_gpu=use_gpu)

            print('epoch =%d, f1 score mean=%f, accuracy mean=%f' %
                  (epoch, f1score_arr.mean(), acc_arr.mean()))
            print >> res_file, '%d\t%f\t%f' % (epoch, f1score_arr.mean(),
                                               acc_arr.mean())
            base_net.train(True)
            land_enc.train(True)
            au_enc.train(True)
            invar_shape_enc.train(True)
            feat_gen.train(True)

        if epoch > config.start_epoch:
            print('taking snapshot ...')
            torch.save(
                base_net.state_dict(), config.write_path_prefix + config.mode +
                '/base_net_' + str(epoch) + '.pth')
            torch.save(
                land_enc.state_dict(), config.write_path_prefix + config.mode +
                '/land_enc_' + str(epoch) + '.pth')
            torch.save(
                au_enc.state_dict(), config.write_path_prefix + config.mode +
                '/au_enc_' + str(epoch) + '.pth')
            torch.save(
                invar_shape_enc.state_dict(), config.write_path_prefix +
                config.mode + '/invar_shape_enc_' + str(epoch) + '.pth')
            torch.save(
                feat_gen.state_dict(), config.write_path_prefix + config.mode +
                '/feat_gen_' + str(epoch) + '.pth')
            torch.save(
                invar_shape_disc.state_dict(), config.write_path_prefix +
                config.mode + '/invar_shape_disc_' + str(epoch) + '.pth')
            torch.save(
                feat_gen_disc_src.state_dict(), config.write_path_prefix +
                config.mode + '/feat_gen_disc_src_' + str(epoch) + '.pth')
            torch.save(
                feat_gen_disc_tgt.state_dict(), config.write_path_prefix +
                config.mode + '/feat_gen_disc_tgt_' + str(epoch) + '.pth')

        if epoch >= config.n_epochs:
            break

        for i, batch_src in enumerate(dset_loaders['source']['train']):
            if i % config.display == 0 and count > 0:
                print(
                    '[epoch = %d][iter = %d][loss_disc = %f][loss_invar_shape_disc = %f][loss_gen_disc = %f][total_loss = %f][loss_invar_shape_adaptation = %f][loss_gen_adaptation = %f][loss_self_recons = %f][loss_gen_cycle = %f][loss_au = %f][loss_land = %f]'
                    %
                    (epoch, i, loss_disc.data.cpu().numpy(),
                     loss_invar_shape_disc.data.cpu().numpy(),
                     loss_gen_disc.data.cpu().numpy(),
                     total_loss.data.cpu().numpy(),
                     loss_invar_shape_adaptation.data.cpu().numpy(),
                     loss_gen_adaptation.data.cpu().numpy(),
                     loss_self_recons.data.cpu().numpy(),
                     loss_gen_cycle.data.cpu().numpy(),
                     loss_au.data.cpu().numpy(), loss_land.data.cpu().numpy()))

                print('learning rate = %f, %f, %f' %
                      (Disc_optimizer.param_groups[0]['lr'],
                       Gen_optimizer.param_groups[0]['lr'],
                       Task_optimizer.param_groups[0]['lr']))
                print('the number of training iterations is %d' % (count))

            input_src, land_src, au_src = batch_src
            if count % len_train_tgt == 0:
                if count > 0:
                    dset_loaders['target']['train'] = util_data.DataLoader(
                        dsets['target']['train'],
                        batch_size=config.train_batch_size,
                        shuffle=True,
                        num_workers=config.num_workers)
                iter_data_tgt = iter(dset_loaders['target']['train'])
            input_tgt, land_tgt, au_tgt = iter_data_tgt.next()

            if input_tgt.size(0) > input_src.size(0):
                input_tgt, land_tgt, au_tgt = input_tgt[
                    0:input_src.size(0), :, :, :], land_tgt[
                        0:input_src.size(0), :], au_tgt[0:input_src.size(0)]
            elif input_tgt.size(0) < input_src.size(0):
                input_src, land_src, au_src = input_src[
                    0:input_tgt.size(0), :, :, :], land_src[
                        0:input_tgt.size(0), :], au_src[0:input_tgt.size(0)]

            if use_gpu:
                input_src, land_src, au_src, input_tgt, land_tgt, au_tgt = \
                    input_src.cuda(), land_src.long().cuda(), au_src.float().cuda(), \
                    input_tgt.cuda(), land_tgt.long().cuda(), au_tgt.float().cuda()
            else:
                land_src, au_src, land_tgt, au_tgt = \
                    land_src.long(), au_src.float(), land_tgt.long(), au_tgt.float()

            land_enc_store.load_state_dict(land_enc.state_dict())

            base_feat_src = base_net(input_src)
            align_attention_src, align_feat_src, align_output_src = land_enc(
                base_feat_src)
            au_feat_src, au_output_src = au_enc(base_feat_src)

            base_feat_tgt = base_net(input_tgt)
            align_attention_tgt, align_feat_tgt, align_output_tgt = land_enc(
                base_feat_tgt)
            au_feat_tgt, au_output_tgt = au_enc(base_feat_tgt)

            invar_shape_output_src = invar_shape_enc(base_feat_src.detach())
            invar_shape_output_tgt = invar_shape_enc(base_feat_tgt.detach())

            # new_gen
            new_gen_tgt = feat_gen(align_attention_src.detach(),
                                   invar_shape_output_tgt)
            new_gen_src = feat_gen(align_attention_tgt.detach(),
                                   invar_shape_output_src)

            # recons_gen
            recons_gen_src = feat_gen(align_attention_src.detach(),
                                      invar_shape_output_src)
            recons_gen_tgt = feat_gen(align_attention_tgt.detach(),
                                      invar_shape_output_tgt)

            # new2_gen
            new_gen_invar_shape_output_src = invar_shape_enc(
                new_gen_src.detach())
            new_gen_invar_shape_output_tgt = invar_shape_enc(
                new_gen_tgt.detach())
            new_gen_align_attention_src, new_gen_align_feat_src, new_gen_align_output_src = land_enc_store(
                new_gen_src)
            new_gen_align_attention_tgt, new_gen_align_feat_tgt, new_gen_align_output_tgt = land_enc_store(
                new_gen_tgt)
            new2_gen_tgt = feat_gen(new_gen_align_attention_src.detach(),
                                    new_gen_invar_shape_output_tgt)
            new2_gen_src = feat_gen(new_gen_align_attention_tgt.detach(),
                                    new_gen_invar_shape_output_src)

            ############################
            # 1. train discriminator #
            ############################
            Disc_optimizer = Disc_lr_scheduler(Disc_param_lr, Disc_optimizer,
                                               epoch, config.n_epochs, 1,
                                               config.decay_start_epoch,
                                               config.gen_lr)
            Disc_optimizer.zero_grad()

            align_output_invar_shape_src = invar_shape_disc(
                invar_shape_output_src.detach())
            align_output_invar_shape_tgt = invar_shape_disc(
                invar_shape_output_tgt.detach())

            # loss_invar_shape_disc
            loss_base_invar_shape_disc_src = land_discriminator_criterion(
                align_output_invar_shape_src, land_src)
            loss_base_invar_shape_disc_tgt = land_discriminator_criterion(
                align_output_invar_shape_tgt, land_tgt)
            loss_invar_shape_disc = (loss_base_invar_shape_disc_src +
                                     loss_base_invar_shape_disc_tgt) * 0.5

            base_gen_src_pred = feat_gen_disc_src(base_feat_src.detach())
            new_gen_src_pred = feat_gen_disc_src(new_gen_src.detach())

            real_label = torch.ones((base_feat_src.size(0), 1))
            fake_label = torch.zeros((base_feat_src.size(0), 1))
            if use_gpu:
                real_label, fake_label = real_label.cuda(), fake_label.cuda()
            # loss_gen_disc_src
            loss_base_gen_src = discriminator_criterion(
                base_gen_src_pred, real_label)
            loss_new_gen_src = discriminator_criterion(new_gen_src_pred,
                                                       fake_label)
            loss_gen_disc_src = (loss_base_gen_src + loss_new_gen_src) * 0.5

            base_gen_tgt_pred = feat_gen_disc_tgt(base_feat_tgt.detach())
            new_gen_tgt_pred = feat_gen_disc_tgt(new_gen_tgt.detach())

            # loss_gen_disc_tgt
            loss_base_gen_tgt = discriminator_criterion(
                base_gen_tgt_pred, real_label)
            loss_new_gen_tgt = discriminator_criterion(new_gen_tgt_pred,
                                                       fake_label)
            loss_gen_disc_tgt = (loss_base_gen_tgt + loss_new_gen_tgt) * 0.5
            # loss_gen_disc
            loss_gen_disc = (loss_gen_disc_src + loss_gen_disc_tgt) * 0.5

            loss_disc = loss_invar_shape_disc + loss_gen_disc

            loss_disc.backward()

            # optimize discriminator
            Disc_optimizer.step()

            ############################
            # 2. train base network #
            ############################
            Gen_optimizer = Gen_lr_scheduler(Gen_param_lr, Gen_optimizer,
                                             epoch, config.n_epochs, 1,
                                             config.decay_start_epoch,
                                             config.gen_lr)
            Gen_optimizer.zero_grad()
            Task_optimizer = Task_lr_scheduler(Task_param_lr, Task_optimizer,
                                               epoch, config.n_epochs, 1,
                                               config.decay_start_epoch,
                                               config.task_lr)
            Task_optimizer.zero_grad()

            align_output_invar_shape_src = invar_shape_disc(
                invar_shape_output_src)
            align_output_invar_shape_tgt = invar_shape_disc(
                invar_shape_output_tgt)

            # loss_invar_shape_adaptation
            loss_base_invar_shape_adaptation_src = land_adaptation_criterion(
                align_output_invar_shape_src)
            loss_base_invar_shape_adaptation_tgt = land_adaptation_criterion(
                align_output_invar_shape_tgt)
            loss_invar_shape_adaptation = (
                loss_base_invar_shape_adaptation_src +
                loss_base_invar_shape_adaptation_tgt) * 0.5

            new_gen_src_pred = feat_gen_disc_src(new_gen_src)
            loss_gen_adaptation_src = discriminator_criterion(
                new_gen_src_pred, real_label)

            new_gen_tgt_pred = feat_gen_disc_tgt(new_gen_tgt)
            loss_gen_adaptation_tgt = discriminator_criterion(
                new_gen_tgt_pred, real_label)
            # loss_gen_adaptation
            loss_gen_adaptation = (loss_gen_adaptation_src +
                                   loss_gen_adaptation_tgt) * 0.5

            loss_gen_cycle_src = reconstruct_criterion(new2_gen_src,
                                                       base_feat_src.detach())
            loss_gen_cycle_tgt = reconstruct_criterion(new2_gen_tgt,
                                                       base_feat_tgt.detach())
            # loss_gen_cycle
            loss_gen_cycle = (loss_gen_cycle_src + loss_gen_cycle_tgt) * 0.5

            loss_self_recons_src = reconstruct_criterion(
                recons_gen_src, base_feat_src.detach())
            loss_self_recons_tgt = reconstruct_criterion(
                recons_gen_tgt, base_feat_tgt.detach())
            # loss_self_recons
            loss_self_recons = (loss_self_recons_src +
                                loss_self_recons_tgt) * 0.5

            loss_base_gen_au_src = au_class_criterion(au_output_src, au_src)
            loss_base_gen_au_tgt = au_class_criterion(au_output_tgt, au_tgt)
            loss_base_gen_land_src = land_predict_criterion(
                align_output_src, land_src)
            loss_base_gen_land_tgt = land_predict_criterion(
                align_output_tgt, land_tgt)

            new_gen_au_feat_src, new_gen_au_output_src = au_enc(new_gen_src)
            new_gen_au_feat_tgt, new_gen_au_output_tgt = au_enc(new_gen_tgt)
            loss_new_gen_au_src = au_class_criterion(new_gen_au_output_src,
                                                     au_tgt)
            loss_new_gen_au_tgt = au_class_criterion(new_gen_au_output_tgt,
                                                     au_src)
            loss_new_gen_land_src = land_predict_criterion(
                new_gen_align_output_src, land_tgt)
            loss_new_gen_land_tgt = land_predict_criterion(
                new_gen_align_output_tgt, land_src)

            # loss_land
            loss_land = (loss_base_gen_land_src + loss_base_gen_land_tgt +
                         loss_new_gen_land_src + loss_new_gen_land_tgt) * 0.5
            # loss_au
            if config.mode == 'weak':
                loss_au = (loss_base_gen_au_src + loss_new_gen_au_tgt) * 0.5
            else:
                loss_au = (loss_base_gen_au_src + loss_base_gen_au_tgt +
                           loss_new_gen_au_src + loss_new_gen_au_tgt) * 0.25

            total_loss = config.lambda_land_adv * loss_invar_shape_adaptation + \
                         config.lambda_feat_adv * loss_gen_adaptation + \
                         config.lambda_cross_cycle * loss_gen_cycle + config.lambda_self_recons * loss_self_recons + \
                         config.lambda_au * loss_au + config.lambda_land * loss_land

            total_loss.backward()
            Gen_optimizer.step()
            Task_optimizer.step()

            count = count + 1

    res_file.close()
Exemplo n.º 4
0
def main(config):
    """Offline evaluation of saved AU-detection checkpoints on the test set.

    For every epoch in ``[config.start_epoch, config.n_epochs]`` the six
    network modules are reloaded from disk, then — depending on the switches
    ``config.pred_AU`` and ``config.vis_attention`` — AU-prediction metrics
    are printed/appended to a result file and/or attention maps are dumped
    to disk.

    Args:
        config: namespace-like object carrying dataset paths, network names
            (keys into ``network.network_dict``), model hyper-parameters and
            the evaluation switches used below.

    Raises:
        RuntimeError: if ``config.start_epoch`` is not positive.
    """
    use_gpu = torch.cuda.is_available()

    ## prepare data
    dsets = {}
    dset_loaders = {}
    dsets['test'] = ImageList(
        crop_size=config.crop_size,
        path=config.test_path_prefix,
        phase='test',
        transform=prep.image_test(crop_size=config.crop_size),
        target_transform=prep.land_transform(img_size=config.crop_size,
                                             flip_reflect=np.loadtxt(
                                                 config.flip_reflect)))

    dset_loaders['test'] = util_data.DataLoader(
        dsets['test'],
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=config.num_workers)

    ## set network modules (all looked up by name in network.network_dict)
    region_learning = network.network_dict[config.region_learning](
        input_dim=3, unit_dim=config.unit_dim)
    align_net = network.network_dict[config.align_net](
        crop_size=config.crop_size,
        map_size=config.map_size,
        au_num=config.au_num,
        land_num=config.land_num,
        input_dim=config.unit_dim * 8)
    local_attention_refine = network.network_dict[
        config.local_attention_refine](au_num=config.au_num,
                                       unit_dim=config.unit_dim)
    local_au_net = network.network_dict[config.local_au_net](
        au_num=config.au_num,
        input_dim=config.unit_dim * 8,
        unit_dim=config.unit_dim)
    global_au_feat = network.network_dict[config.global_au_feat](
        input_dim=config.unit_dim * 8, unit_dim=config.unit_dim)
    au_net = network.network_dict[config.au_net](au_num=config.au_num,
                                                 input_dim=12000,
                                                 unit_dim=config.unit_dim)

    modules = (region_learning, align_net, local_attention_refine,
               local_au_net, global_au_feat, au_net)
    if use_gpu:
        modules = tuple(m.cuda() for m in modules)
        (region_learning, align_net, local_attention_refine,
         local_au_net, global_au_feat, au_net) = modules

    ckpt_dir = config.write_path_prefix + config.run_name
    res_dir = config.write_res_prefix + config.run_name
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    if config.start_epoch <= 0:
        raise RuntimeError('start_epoch should be larger than 0\n')

    res_path = (res_dir + '/' + config.prefix + 'offline_AU_pred_' +
                str(config.start_epoch) + '.txt')

    # Evaluation only: put every module in inference mode.
    for module in modules:
        module.train(False)

    # `with` guarantees the result file is closed even if evaluation raises.
    with open(res_path, 'w') as res_file:
        for epoch in range(config.start_epoch, config.n_epochs + 1):
            # Reload this epoch's checkpoint for every module.
            region_learning.load_state_dict(
                torch.load(ckpt_dir + '/region_learning_' + str(epoch) +
                           '.pth'))
            align_net.load_state_dict(
                torch.load(ckpt_dir + '/align_net_' + str(epoch) + '.pth'))
            local_attention_refine.load_state_dict(
                torch.load(ckpt_dir + '/local_attention_refine_' +
                           str(epoch) + '.pth'))
            local_au_net.load_state_dict(
                torch.load(ckpt_dir + '/local_au_net_' + str(epoch) +
                           '.pth'))
            global_au_feat.load_state_dict(
                torch.load(ckpt_dir + '/global_au_feat_' + str(epoch) +
                           '.pth'))
            au_net.load_state_dict(
                torch.load(ckpt_dir + '/au_net_' + str(epoch) + '.pth'))

            if config.pred_AU:
                local_f1score_arr, local_acc_arr, f1score_arr, acc_arr, mean_error, failure_rate = AU_detection_evalv2(
                    dset_loaders['test'],
                    region_learning,
                    align_net,
                    local_attention_refine,
                    local_au_net,
                    global_au_feat,
                    au_net,
                    use_gpu=use_gpu)
                print(
                    'epoch =%d, local f1 score mean=%f, local accuracy mean=%f, '
                    'f1 score mean=%f, accuracy mean=%f, mean error=%f, failure rate=%f'
                    %
                    (epoch, local_f1score_arr.mean(), local_acc_arr.mean(),
                     f1score_arr.mean(), acc_arr.mean(), mean_error,
                     failure_rate))
                print(
                    '%d\t%f\t%f\t%f\t%f\t%f\t%f' %
                    (epoch, local_f1score_arr.mean(), local_acc_arr.mean(),
                     f1score_arr.mean(), acc_arr.mean(), mean_error,
                     failure_rate),
                    file=res_file)
            if config.vis_attention:
                os.makedirs(res_dir + '/vis_map/' + str(epoch),
                            exist_ok=True)
                os.makedirs(res_dir + '/overlay_vis_map/' + str(epoch),
                            exist_ok=True)

                vis_attention(dset_loaders['test'],
                              region_learning,
                              align_net,
                              local_attention_refine,
                              config.write_res_prefix,
                              config.run_name,
                              epoch,
                              use_gpu=use_gpu)
Exemplo n.º 5
0
def perform_operation(file_path, operation, epoch):
    """Run one epoch of training or evaluation over the images in file_path.

    Args:
        file_path: path of the image-list file handed to ImageList.
        operation: 'Train' to optimise the networks, anything else (the
            callers use 'Test') to evaluate only.
        epoch: current epoch index, used for logging and checkpoint naming.

    Side effects: uses the module-level networks (e_net, a_net, s_net,
    fusion), optimizer, args, and — in Test mode — writes predictions and
    metrics to the fout_test* result files. In Train mode a checkpoint is
    saved at the end of the epoch.
    """
    is_train = operation == 'Train'

    # Put every sub-network in the matching train/eval mode.
    for net in (e_net, a_net, s_net, fusion):
        net.train(is_train)

    # BUG FIX: the previous code called torch.no_grad()/torch.enable_grad()
    # as bare statements, which builds the context manager and discards it
    # without entering it — gradients were still tracked during Test.
    # Gradient tracking is now governed by torch.set_grad_enabled below.

    imDataset = ImageList(
        crop_size=args.IM_SIZE,
        path=file_path,
        img_path=args.img_path,
        NUM_CLASS=args.NUM_CLASS,
        phase='test',
        transform=prep.image_test(crop_size=args.IM_SIZE),
        target_transform=prep.land_transform(img_size=args.IM_SIZE))
    imDataLoader = torch.utils.data.DataLoader(
        imDataset,
        batch_size=args.Train_BATCH if is_train else args.Test_BATCH,
        shuffle=is_train,
        num_workers=0)

    # Loss modules are stateless: build them once, not once per batch.
    bceLoss_cls = nn.BCEWithLogitsLoss()
    bceLoss2_att = nn.BCEWithLogitsLoss()

    all_output = None
    all_label = None

    with torch.set_grad_enabled(is_train):
        for batch_Idx, data in enumerate(imDataLoader):
            if is_train:
                print('%s Epoch: %d Batch_Idx: %d' %
                      (operation, epoch, batch_Idx))
                optimizer.zero_grad()

            datablob, datalb, pos_para = data

            # torch.autograd.Variable is a no-op wrapper in modern PyTorch;
            # plain tensors moved to GPU behave identically.
            datablob = datablob.cuda()
            y_lb = datalb.view(datalb.size(0), -1).cuda()
            pos_para = pos_para.cuda()

            # Forward: global branch, attention branch, sliced local branch,
            # fused classification.
            pred_global = e_net(datablob)
            feat_data = e_net.predict_BN(datablob)
            pred_att_map, pred_conf = a_net(feat_data)
            slice_feat_data = prep_model_input(pred_att_map, pos_para)
            pred_local = s_net(slice_feat_data)
            cls_pred = fusion(pred_global + pred_local)

            cls_loss = bceLoss_cls(cls_pred, y_lb)
            att_loss = bceLoss2_att(pred_conf, y_lb)
            sum_loss = cls_loss + att_loss

            if is_train:
                sum_loss.backward()

            cls_pred = cls_pred.data.cpu().float()
            y_lb = y_lb.data.cpu().float()
            f1_score = get_f1(cls_pred, y_lb)
            acc_scr = get_acc(cls_pred, y_lb)

            if is_train:
                print('acc_scr',
                      acc_scr.mean().cpu().data.item(), 'f1_score',
                      f1_score.mean().cpu().data.item(), 'sum_loss',
                      sum_loss.cpu().data.item())
                optimizer.step()
            else:
                # Accumulate predictions for whole-set metrics.
                if all_output is None:
                    all_output = cls_pred
                    all_label = y_lb
                else:
                    all_output = torch.cat((all_output, cls_pred), 0)
                    all_label = torch.cat((all_label, y_lb), 0)
                fout_test.write('Label:' + str(y_lb) + '->' + 'Pre:' +
                                str(cls_pred) + '\n')

            # Drop per-batch tensors promptly to keep peak GPU memory low.
            del datablob, y_lb, pos_para, feat_data, pred_att_map, pred_conf, \
                slice_feat_data, pred_local, cls_pred, cls_loss, att_loss, \
                sum_loss, acc_scr, f1_score

    if not is_train:
        all_acc_scr = get_acc(all_output, all_label)
        all_f1_score = get_f1(all_output, all_label)

        fout_test_f1.write('***' + str(all_f1_score.numpy().tolist()) + '\n')
        fout_test_f1_mean.write('***' +
                                str(all_f1_score.mean().numpy().tolist()) +
                                '\n')
        fout_test_acc.write('***' + str(all_acc_scr.numpy().tolist()) + '\n')
        fout_test_acc_mean.write('***' +
                                 str(all_acc_scr.mean().numpy().tolist()) +
                                 '\n')

        print('average f1 score: ', str(all_f1_score.mean().numpy().tolist()))
        print('average acc score: ', str(all_acc_scr.mean().numpy().tolist()))

        del all_acc_scr, all_f1_score, all_output, all_label
    else:
        new_model = './result/snap/' + args.version + '/WS-DAFNet_' + \
            args.name + '_' + str(epoch) + '.pth'
        torch.save([e_net, a_net, s_net, fusion], new_model)
        print('save ' + new_model)