Example #1
File: trainer.py Project: XsLangley/TDGAN
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image

# project-local modules
import LoadData
import model
import util


def train(hpar_dict):

    # region: hyper-parameters for the model

    # noise channel
    Nz = hpar_dict['Nz']

    # step gaps for updating the two discriminators
    D_GAP_FR = hpar_dict['D_GAP_FR']
    D_GAP_ER = hpar_dict['D_GAP_ER']
    # step gap for saving sample images
    IMG_SAVE_GAP = hpar_dict['IMG_SAVE_GAP']
    # gap (in epochs) for saving parameters
    PAR_SAVE_GAP = hpar_dict['PAR_SAVE_GAP']
    # gap (in epochs) for validation
    VAL_GAP = hpar_dict['VAL_GAP']
    # batch size
    BS = hpar_dict['BS']
    # training epochs
    epoch = hpar_dict['epoch']
    # number of face recognition classes (identities)
    FR_cls_num = hpar_dict['FR_cls_num']

    # learning rate
    LR_D_FR = hpar_dict['LR_D_FR']
    LR_D_ER = hpar_dict['LR_D_ER']
    LR_G_FR = hpar_dict['LR_G_FR']
    LR_G_ER = hpar_dict['LR_G_ER']
    # weights to balance the loss terms of the generator and discriminators
    H_G_FR_f = hpar_dict['H_G_FR_f']
    H_G_ER_f = hpar_dict['H_G_ER_f']
    H_G_FR_PER = hpar_dict['H_G_FR_PER']
    H_G_ER_PER = hpar_dict['H_G_ER_PER']
    H_G_CON_FR = hpar_dict['H_G_CON_FR']
    H_G_CON_ER = hpar_dict['H_G_CON_ER']

    H_D_FR_r = hpar_dict['H_D_FR_r']
    H_D_FR_f = hpar_dict['H_D_FR_f']
    H_D_ER_r = hpar_dict['H_D_ER_r']
    H_D_ER_f = hpar_dict['H_D_ER_f']
    # flag to indicate whether to generate grayscale images
    FLAG_GEN_GRAYIMG = hpar_dict['FLAG_GEN_GRAYIMG']
    # dataset
    ER_DB = hpar_dict['ER_DB']
    FR_DB = hpar_dict['FR_DB']
    save_dir = hpar_dict['save_dir']
    # directory to save images
    img_dir = os.path.join(save_dir, 'img')
    os.makedirs(img_dir, exist_ok=True)
    # directory to save parameters
    par_dir = os.path.join(save_dir, 'par')
    os.makedirs(par_dir, exist_ok=True)

    TRAIN_FLAG = hpar_dict['train']

    # endregion

    # region: construct dataloaders for training

    # prepare the face recognition dataset
    if FR_DB == 'CASIA':
        # load the npz file that stores images' data (dir, label)
        face_data_dir = './Dataset/CASIA_WebFace/casia_data_example.npz'
        face_img_dir_list, face_img_lab_list = LoadData.getFaceData(
            face_data_dir, FR_cls_num)
        # the root directory where the images are stored (on your local machine)
        face_root_dir = './Dataset/CASIA_WebFace/img/'

    if ER_DB == 'RAF':
        ER_cls_num = 7  # the number of expression classes in RAF-DB
        acc_max = 0.1

        # npy file that stores the data of RAF-DB (img name, labels)
        expr_data_dir = './Dataset/RAF/RAF_example_label.npy'

        expr_root_dir = './Dataset/RAF/img/'
        lab_dir = expr_data_dir
        # the dataset instance for testing
        dataset_test = LoadData.RAFDataset(expr_root_dir,
                                           lab_dir,
                                           RGB_flag=not FLAG_GEN_GRAYIMG,
                                           train=False)

    # dataset instance for training
    dataset_train = LoadData.DualTrainDatasetRAF(face_root_dir,
                                                 face_img_dir_list,
                                                 face_img_lab_list,
                                                 expr_data_dir,
                                                 expr_root_dir,
                                                 GRAY_flag=FLAG_GEN_GRAYIMG)

    # dataloader for training, which contains two datasets
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=BS,
                                               shuffle=True,
                                               num_workers=2)
    # dataloader for testing, which only contains the expression dataset
    expr_tt_loader = torch.utils.data.DataLoader(dataset_test,
                                                 batch_size=BS,
                                                 shuffle=False,
                                                 num_workers=2)

    # endregion

    # region: module instantiation
    # instantiate Generator
    Gen = model.Gen(clsn_ER=ER_cls_num, Nz=Nz, GRAY=FLAG_GEN_GRAYIMG, Nb=6)

    # instantiate face discriminator
    Dis_FR = model.Dis(GRAY=FLAG_GEN_GRAYIMG, cls_num=FR_cls_num + 1)
    # instantiate expression discriminator
    Dis_ER = model.Dis(GRAY=FLAG_GEN_GRAYIMG, cls_num=ER_cls_num)

    # instantiate the Expression Classification Module (M_ER)
    Dis_ER_val = model.Dis()
    Dis_ER_val.enc = Gen.enc_ER
    Dis_ER_val.fc = Gen.fc_ER
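    # note: enc and fc are shared references to the generator's expression
    # branch, so M_ER is updated implicitly whenever Gen is trained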

    # move the models to the target device
    Gen.to(hpar_dict['device'])
    Dis_FR.to(hpar_dict['device'])
    Dis_ER.to(hpar_dict['device'])
    Dis_ER_val.to(hpar_dict['device'])

    # endregion

    # region optimizer definition

    # parameters of the generator
    par_list_G_joint = [{
        'params': Gen.dec.parameters(),
        'lr': LR_G_ER
    }, {
        'params': Gen.enc_FR.parameters(),
        'lr': LR_G_FR
    }, {
        'params': Gen.enc_ER.parameters(),
        'lr': LR_G_ER
    }]
    # parameters of the Expression Classification Module (its fc layer)
    par_list_G_ER_fc = [
        {
            'params': Gen.fc_ER.parameters(),
            'lr': LR_D_ER
        },
    ]
    # parameters of the two discriminators
    par_list_D_FR = [
        {
            'params': Dis_FR.parameters(),
            'lr': LR_D_FR
        },
    ]
    par_list_D_ER = [
        {
            'params': Dis_ER.parameters(),
            'lr': LR_D_ER
        },
    ]
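    # Adam applies each group's 'lr' entry above, so every sub-module is
    # optimized with its own learning rate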

    optG_joint = optim.Adam(par_list_G_joint)
    optG_ER_fc = optim.Adam(par_list_G_ER_fc)
    optD_FR = optim.Adam(par_list_D_FR)
    optD_ER = optim.Adam(par_list_D_ER)

    # endregion

    # criterion for loss
    CE = nn.CrossEntropyLoss()
    MSE = nn.MSELoss()
    L1_loss = nn.L1Loss()

    if not TRAIN_FLAG:
        # region: load pretrained model and do forward steps

        # pretrained model dir
        pre_root_dir = './Dataset/examples'
        par_Enc_FR_dir = os.path.join(pre_root_dir, 'Enc_FR_G.pkl')
        par_Enc_ER_dir = os.path.join(pre_root_dir, 'Enc_ER_G.pkl')
        par_dec_dir = os.path.join(pre_root_dir, 'dec.pkl')
        par_fc_ER_dir = os.path.join(pre_root_dir, 'fc_ER_G.pkl')

        par_Dis_ER_dir = os.path.join(pre_root_dir, 'Dis_ER.pkl')
        # load parameters
        print('loading pretrained models......')
        Gen.enc_FR.load_state_dict(util.del_extra_keys(par_Enc_FR_dir))
        Gen.enc_ER.load_state_dict(util.del_extra_keys(par_Enc_ER_dir))
        Gen.dec.load_state_dict(util.del_extra_keys(par_dec_dir))
        Gen.fc_ER.load_state_dict(util.del_extra_keys(par_fc_ER_dir))
        Dis_ER_val.enc.load_state_dict(util.del_extra_keys(par_Enc_ER_dir))
        Dis_ER_val.fc.load_state_dict(util.del_extra_keys(par_fc_ER_dir))
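        # (Dis_ER_val shares enc and fc with Gen.enc_ER / Gen.fc_ER, so the
        # two loads above are redundant but harmless)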

        Dis_ER.load_state_dict(util.del_extra_keys(par_Dis_ER_dir))

        # load example images for validation demo
        face_img_dir = '{}/face.jpg'.format(pre_root_dir)
        expression_img_dir = '{}/expression.jpg'.format(pre_root_dir)
        faceimg = util.preprocess_img(face_img_dir, hpar_dict['device'])
        exprimg = util.preprocess_img(expression_img_dir, hpar_dict['device'])

        with torch.no_grad():
            Gen.eval()
            gen_img = Gen.gen_img(faceimg, exprimg, device=hpar_dict['device'])
            gen_img_copy = gen_img.detach()
            gen_img = gen_img.squeeze().cpu().data.numpy()
            img_PIL = Image.fromarray((gen_img * 255).astype(np.uint8))
            img_PIL.save('{}/generated_example.png'.format(pre_root_dir))

            # validation using M_ER
            Dis_ER_val.eval()
            util.Val_acc_single(gen_img_copy,
                                Dis_ER_val,
                                device=hpar_dict['device'],
                                name='M_ER')

            # validation using discriminator
            Dis_ER.eval()
            util.Val_acc_single(gen_img_copy,
                                Dis_ER,
                                device=hpar_dict['device'],
                                name='Expr Dis')

        # endregion
    else:

        # buffers for validation accuracy / cross entropy (M_ER)
        tt_acc_mat = []
        tt_ce_mat = []
        # buffers for validation accuracy / cross entropy (expression discriminator)
        tt_acc_mat_ExpDis = []
        tt_ce_mat_ExpDis = []

        # start training
        for e in range(1, epoch + 1):
            print('---- training ----')
            # the number of steps in one epoch
            step_total = len(train_loader)
            t_start = time.time()
            print('the %d-th training epoch' % (e))

            # set training mode
            Gen.train()
            Dis_ER.train()
            Dis_FR.train()

            for step, (batch_FR_x_r, batch_FR_y_r, batch_ER_x_r,
                       batch_ER_y_r) in enumerate(train_loader):

                # region batch data preparation

                # batch_FR_x_r: real batch data for Face Recognition
                # batch_FR_y_r: real batch labels for Face Recognition
                # batch_ER_x_r: real batch data for expression Recognition
                # batch_ER_y_r: real batch labels for expression Recognition

                # labels for fake images
                batch_FR_y_f = torch.full((len(batch_FR_y_r), ),
                                          FR_cls_num,
                                          dtype=torch.long)
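                # (class index FR_cls_num is the extra "fake" class, since
                # Dis_FR was built with FR_cls_num + 1 output classes)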

                # move all tensors to the target device
                batch_FR_x_r = batch_FR_x_r.to(hpar_dict['device'])
                batch_FR_y_r = batch_FR_y_r.long().to(hpar_dict['device'])

                batch_ER_x_r = batch_ER_x_r.to(hpar_dict['device'])
                batch_ER_y_r = batch_ER_y_r.long().to(hpar_dict['device'])

                batch_FR_y_f = batch_FR_y_f.to(hpar_dict['device'])
                # endregion

                # region forward step

                # go through the discriminators
                batch_FR_Dfea_r, batch_FR_Dp_r = Dis_FR(batch_FR_x_r)
                batch_FR_Dfea_r = batch_FR_Dfea_r.detach()
                batch_ER_Dfea_r, batch_ER_Dp_r = Dis_ER(batch_ER_x_r)
                batch_ER_Dfea_r = batch_ER_Dfea_r.detach()

                # loss of face discriminator (with respect to real samples)
                loss_D_FR_r = CE(batch_FR_Dp_r, batch_FR_y_r)
                # loss of expression discriminator (with respect to real samples)
                loss_D_ER_r = CE(batch_ER_Dp_r, batch_ER_y_r)

                # generate images
                batch_x_f = Gen.gen_img(batch_FR_x_r,
                                        batch_ER_x_r,
                                        device=hpar_dict['device'])
                batch_ER_Gfea_r = Gen.fea_ER.detach()

                # region: update the Expression Classification Module (M_ER)
                optG_ER_fc.zero_grad()
                err_G_ER_r = CE(Gen.result_ER, batch_ER_y_r)
                err_G_ER_r.backward(retain_graph=True)
                optG_ER_fc.step()
                # endregion

                # endregion

                # region update discriminators

                # clear gradient buffer
                optD_FR.zero_grad()
                optD_ER.zero_grad()
                if step % D_GAP_FR == 0:
                    batch_FR_Dfea_f, batch_FR_Dp_f = Dis_FR(batch_x_f.detach())
                    # loss of face discriminator (with respect to fake samples)
                    loss_D_FR_f = CE(batch_FR_Dp_f, batch_FR_y_f)

                    # full loss of face discriminator
                    loss_D_FR = H_D_FR_r * loss_D_FR_r + H_D_FR_f * loss_D_FR_f

                    loss_D_FR.backward()
                    optD_FR.step()

                if step % D_GAP_ER == 0:

                    # loss of expression discriminator
                    loss_D_ER = H_D_ER_r * loss_D_ER_r
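                    # (only the real-sample term is applied; H_D_ER_f is never
                    # used, so Dis_ER is not trained on generated images)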

                    loss_D_ER.backward()
                    optD_ER.step()

                # endregion

                # region update generator

                optG_joint.zero_grad()

                # get the predicted results on fake samples
                batch_FR_Dfea_f, batch_FR_Dp_f = Dis_FR(batch_x_f)
                batch_ER_Dfea_f, batch_ER_Dp_f = Dis_ER(batch_x_f)

                err_G_FR_f = CE(batch_FR_Dp_f,
                                batch_FR_y_r)  # Equ.6 first part
                err_G_ER_f = CE(batch_ER_Dp_f,
                                batch_ER_y_r)  # Equ.6 second part
                err_G_FR_PER = MSE(batch_FR_Dfea_f, batch_FR_Dfea_r)  # Equ.8

                # consistency loss (Fig.3 upper part): face branch takes the
                # generated image, expression branch takes the original face
                # image; the expected output matches the original face image
                batch_x_f_FR = Gen.gen_img(batch_x_f,
                                           batch_FR_x_r,
                                           device=hpar_dict['device'])
                # consistency loss (Fig.3 lower part): face branch takes the
                # original expression image, expression branch takes the
                # generated image; the expected output matches the original
                # expression image
                batch_x_f_ER = Gen.gen_img(batch_ER_x_r,
                                           batch_x_f,
                                           device=hpar_dict['device'])

                # expression perceptual error (unused)
                batch_ER_Gfea_f = Gen.fea_ER.detach().to(hpar_dict['device'])
                err_G_ER_PER = MSE(batch_ER_Gfea_f, batch_ER_Gfea_r)

                err_G_con_FR = L1_loss(batch_x_f_FR, batch_FR_x_r)
                err_G_con_ER = L1_loss(batch_x_f_ER, batch_ER_x_r)
                err_G_con = H_G_CON_FR * err_G_con_FR + H_G_CON_ER * err_G_con_ER  # Equ.7
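                # total generator loss: adversarial terms (Equ.6), perceptual
                # terms (Equ.8), and the consistency term (Equ.7)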
                loss_G = H_G_FR_f * err_G_FR_f + H_G_ER_f * err_G_ER_f + \
                        H_G_FR_PER * err_G_FR_PER + H_G_ER_PER * err_G_ER_PER + err_G_con

                loss_G.backward()
                optG_joint.step()

                # endregion

                if step % 5 == 0:
                    print('current training losses:')
                    print('%d / %d' % (step, step_total))
                    print('the loss of G (total): %f' % (loss_G.cpu().data))
                    print('the loss of G (face): %f' % (err_G_FR_f.cpu().data))
                    print('the loss of G (expression): %f' %
                          (err_G_ER_f.cpu().data))
                    print('the loss of G (face-per): %f' %
                          (err_G_FR_PER.cpu().data))
                    print('the loss of G (expr-per): %f' %
                          (err_G_ER_PER.cpu().data))
                    print('the loss of G (consistency): %f' %
                          (err_G_con.cpu().data))
                    print('the loss of G (FR-cons): %f' %
                          (err_G_con_FR.cpu().data))
                    print('the loss of G (ER-cons): %f' %
                          (err_G_con_ER.cpu().data))
                    print(
                        '------------------------------------------------------------'
                    )
                    print('the loss of D (face-total): %f' %
                          (loss_D_FR.cpu().data))
                    print('the loss of D (face-real): %f' %
                          (loss_D_FR_r.cpu().data))
                    print('the loss of D (face-fake): %f' %
                          (loss_D_FR_f.cpu().data))
                    print('the loss of D (expression): %f' %
                          (loss_D_ER.cpu().data))

                # save the generated images
                if step % IMG_SAVE_GAP == 0:
                    # combine five images: the real face, the real expression
                    # image, and the three generated images
                    comb_img = util.combinefig_dualcon(
                        batch_FR_x_r.cpu().data.numpy(),
                        batch_ER_x_r.cpu().data.numpy(),
                        batch_x_f.cpu().data.numpy(),
                        batch_x_f_FR.cpu().data.numpy(),
                        batch_x_f_ER.cpu().data.numpy())
                    # save figures
                    comb_img = Image.fromarray(
                        (comb_img * 255).astype(np.uint8))
                    comb_img.save(
                        os.path.join(img_dir,
                                     str(e) + '_' + str(step) + '.jpg'))

            t_end = time.time()
            print('this epoch took %f seconds\n' % (t_end - t_start))

            # region validation

            if e % VAL_GAP == 0:
                Dis_ER_val.eval()
                tt_acc, tt_ce = util.Val_acc(expr_tt_loader,
                                             Dis_ER_val,
                                             CE,
                                             device=hpar_dict['device'])
                tt_acc_mat.append(tt_acc)
                tt_ce_mat.append(tt_ce)
                if tt_acc > acc_max:
                    acc_max = tt_acc
                    torch.save(Gen.enc_ER.state_dict(),
                               os.path.join(par_dir, 'Enc_ER_G.pkl'))
                    torch.save(Gen.enc_FR.state_dict(),
                               os.path.join(par_dir, 'Enc_FR_G.pkl'))
                    torch.save(Gen.fc_ER.state_dict(),
                               os.path.join(par_dir, 'fc_ER_G.pkl'))
                    torch.save(Gen.dec.state_dict(),
                               os.path.join(par_dir, 'dec.pkl'))
                    torch.save(Dis_FR.state_dict(),
                               os.path.join(par_dir, 'Dis_FR.pkl'))
                    torch.save(Dis_ER.state_dict(),
                               os.path.join(par_dir, 'Dis_ER.pkl'))
                print('\n')
                print('the %d-th epoch' % (e))
                print('accuracy is : %f' % (tt_acc))
                print('validation cross entropy is : %f' % (tt_ce))
                print('now the best accuracy is %f\n' % (np.max(tt_acc_mat)))

                # validation using discriminator
                Dis_ER.eval()
                tt_acc_ExpDis, tt_ce_ExpDis = util.Val_acc(
                    expr_tt_loader, Dis_ER, CE, device=hpar_dict['device'])
                tt_acc_mat_ExpDis.append(tt_acc_ExpDis)
                tt_ce_mat_ExpDis.append(tt_ce_ExpDis)
                print('testing using discriminator:')
                print('accuracy is : %f' % (tt_acc_ExpDis))
                print('testing cross entropy is : %f' % (tt_ce_ExpDis))
                print('now the best accuracy is %f\n' %
                      (np.max(tt_acc_mat_ExpDis)))

            if e % PAR_SAVE_GAP == 0:
                torch.save(
                    Gen.enc_ER.state_dict(),
                    os.path.join(par_dir, 'Enc_ER_G_' + str(e) + '.pkl'))
                torch.save(
                    Gen.enc_FR.state_dict(),
                    os.path.join(par_dir, 'Enc_FR_G_' + str(e) + '.pkl'))
                torch.save(Gen.fc_ER.state_dict(),
                           os.path.join(par_dir, 'fc_ER_G_' + str(e) + '.pkl'))
                torch.save(Gen.dec.state_dict(),
                           os.path.join(par_dir, 'dec_' + str(e) + '.pkl'))

                torch.save(Dis_FR.state_dict(),
                           os.path.join(par_dir, 'Dis_FR_' + str(e) + '.pkl'))
                torch.save(Dis_ER.state_dict(),
                           os.path.join(par_dir, 'Dis_ER_' + str(e) + '.pkl'))
            # endregion

        print('end')

        np.savez(os.path.join(save_dir, 'val_data.npz'),
                 tt_acc_mat=tt_acc_mat,
                 tt_ce_mat=tt_ce_mat,
                 tt_acc_mat_ExpDis=tt_acc_mat_ExpDis,
                 tt_ce_mat_ExpDis=tt_ce_mat_ExpDis)
        return tt_acc_mat, tt_ce_mat
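
A minimal usage sketch: the dictionary below covers every key that train() reads; all values are illustrative placeholders, not settings from the original project.

import torch

from trainer import train

hpar_dict = {
    'Nz': 50,                 # noise channels (placeholder)
    'D_GAP_FR': 1,            # update Dis_FR every step
    'D_GAP_ER': 1,            # update Dis_ER every step
    'IMG_SAVE_GAP': 100,      # save sample images every 100 steps
    'PAR_SAVE_GAP': 10,       # save parameters every 10 epochs
    'VAL_GAP': 1,             # validate every epoch
    'BS': 16,                 # batch size
    'epoch': 50,              # training epochs
    'FR_cls_num': 100,        # number of face identities
    'LR_D_FR': 1e-4, 'LR_D_ER': 1e-4,      # discriminator learning rates
    'LR_G_FR': 1e-4, 'LR_G_ER': 1e-4,      # generator learning rates
    'H_G_FR_f': 1.0, 'H_G_ER_f': 1.0,      # generator adversarial weights
    'H_G_FR_PER': 1.0, 'H_G_ER_PER': 0.0,  # perceptual-loss weights
    'H_G_CON_FR': 1.0, 'H_G_CON_ER': 1.0,  # consistency-loss weights
    'H_D_FR_r': 1.0, 'H_D_FR_f': 1.0,      # face discriminator weights
    'H_D_ER_r': 1.0, 'H_D_ER_f': 1.0,      # expression discriminator weights
    'FLAG_GEN_GRAYIMG': True,  # generate grayscale images
    'ER_DB': 'RAF',            # only 'RAF' is handled in train()
    'FR_DB': 'CASIA',          # only 'CASIA' is handled in train()
    'save_dir': './results',
    'train': True,             # False runs the pretrained-model demo instead
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
}

tt_acc_mat, tt_ce_mat = train(hpar_dict)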