Beispiel #1
0
    def infer_testset(self):
        """Run inference over the 300W_LP set and visualize predictions.

        For each batch: run the three continuous-angle heads, log and
        print ground truth vs. prediction for the first sample, then
        draw the pose axes on the de-normalized image and display it
        (blocks on a key press per image).
        """
        dataset = datasets.Pose_300W_LP(data_dir='./300W_LP',
                                        filename_path=self.args.filename_list,
                                        batch_size=self.args.batch_size,
                                        image_size=self.args.image_size)

        for _ in range(dataset.length // self.args.batch_size):
            batch_images, batch_labels, batch_cont_labels = dataset.get()
            # Bug fix: inference must run with dropout disabled. The
            # original fed is_training=True and keep_prob=0.5, leaving
            # dropout (and any training-mode ops) active at test time.
            feed_dict = {
                self.images: batch_images,
                self.is_training: False,
                self.keep_prob: 1.0
            }
            yaw, pitch, roll = self.sess.run(
                [self.yaw_predicted, self.pitch_predicted,
                 self.roll_predicted],
                feed_dict=feed_dict)

            tf.logging.info('[] infer test!!!!!\nGT: {}\nP :{},{},{}'.format(
                batch_cont_labels[0], yaw[0], pitch[0], roll[0]))
            print('[] infer test!!!!!\nGT: {}\nP :{},{},{}'.format(
                batch_cont_labels[0], yaw[0], pitch[0], roll[0]))

            # Undo the ImageNet mean/std normalization so the image can be
            # displayed, then overlay the predicted pose axes.
            img = datasets.unnomalizing(batch_images[0],
                                        [0.485, 0.456, 0.406],
                                        [0.229, 0.224, 0.225])
            out = self.draw_axis(img.astype(np.uint8), yaw, pitch, roll)
            cv2.imshow("demo", out)
            cv2.waitKey(0)

        tf.logging.info('[*] test finished')
        print('[*] test finished')
    def __init__(self, data_dir, transformation, device, num_bin):
        """Set up the evaluation dataset and its DataLoader.

        Args:
            data_dir: root directory of the dataset. NOTE(review): it is
                also assigned to ``args.filename_list`` below — confirm
                the filename list really lives at the data directory path.
            transformation: torchvision transform applied to each sample.
            device: torch device the bin-index tensor is moved to.
            num_bin: number of angle bins (stored; not used in this body).
        """
        args.data_dir = data_dir
        args.filename_list = data_dir
        args.dataset = 'AFLW2000'
        self.transformations = transformation
        self.device = device
        self.num_bin = num_bin

        # Bin indices used to turn classification logits into a continuous
        # angle via softmax expectation.
        # Py3 fix: the original used Py2-only `xrange` here and a `print`
        # statement below.
        self.idx_tensor = torch.FloatTensor(list(range(67))).to(self.device)

        # Dispatch table replaces the original if/elif chain; every entry
        # takes (data_dir, filename_list, transformations).
        dataset_classes = {
            'Pose_300W_LP': datasets.Pose_300W_LP,
            'Pose_300W_LP_random_ds': datasets.Pose_300W_LP_random_ds,
            'AFLW2000': datasets.AFLW2000,
            'AFLW2000_ds': datasets.AFLW2000_ds,
            'BIWI': datasets.BIWI,
            'AFLW': datasets.AFLW,
            'AFLW_aug': datasets.AFLW_aug,
            'AFW': datasets.AFW,
        }
        if args.dataset not in dataset_classes:
            print('Error: not a valid dataset name')
            sys.exit()
        pose_dataset = dataset_classes[args.dataset](args.data_dir,
                                                     args.filename_list,
                                                     self.transformations)

        self.test_loader = torch.utils.data.DataLoader(
            dataset=pose_dataset, batch_size=args.batch_size, num_workers=2)
Beispiel #3
0
def main(_):
    """Train a ResNet-50 head-pose estimator (yaw/pitch/roll) on 300W_LP.

    Builds a HopeNet-style graph: three 66-bin classification heads whose
    softmax expectation also yields a continuous angle, giving a combined
    cross-entropy + MSE loss per angle. Restores either the latest
    checkpoint or the ImageNet ResNet-50 weights (skipping the new
    'Logits' head), then runs the training loop with periodic logging,
    summaries, and checkpointing.
    """
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    images = tf.placeholder(dtype=tf.float32,
                            shape=[None, args.image_size, args.image_size, 3],
                            name='image_data')
    labels = tf.placeholder(dtype=tf.int32, shape=[None, 3], name='cls_label')
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    cont_labels = tf.placeholder(dtype=tf.float32, shape=[None, 3],
                                 name='cont_labels')
    is_training = tf.placeholder(tf.bool, name='is_training')
    num_bins = 66

    # Binned labels: one integer bin index per angle.
    label_yaw = labels[:, 0]
    label_pitch = labels[:, 1]
    label_roll = labels[:, 2]

    # Continuous labels: raw angles for the regression term.
    label_yaw_cont = cont_labels[:, 0]
    label_pitch_cont = cont_labels[:, 1]
    label_roll_cont = cont_labels[:, 2]

    with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
        net, endpoints = nets.resnet_v1.resnet_v1_50(images,
                                                     num_classes=None,
                                                     is_training=is_training)

    with tf.variable_scope('Logits'):
        net = tf.squeeze(net, axis=[1, 2])
        net = slim.dropout(net, keep_prob, scope='scope')
        yaw = slim.fully_connected(net, num_outputs=num_bins,
                                   activation_fn=None, scope='fc_yaw')
        pitch = slim.fully_connected(net, num_outputs=num_bins,
                                     activation_fn=None, scope='fc_pitch')
        roll = slim.fully_connected(net, num_outputs=num_bins,
                                    activation_fn=None, scope='fc_roll')

    # Cross-entropy loss on the binned targets.
    loss_yaw = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(label_yaw, tf.int64), logits=yaw)
    loss_pitch = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(label_pitch, tf.int64), logits=pitch)
    loss_roll = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(label_roll, tf.int64), logits=roll)

    loss_yaw_ce = tf.reduce_mean(loss_yaw)
    loss_pitch_ce = tf.reduce_mean(loss_pitch)
    loss_roll_ce = tf.reduce_mean(loss_roll)

    # MSE loss: expected angle = sum(softmax * bin_index) * 3 - 99,
    # mapping the 66 bins back onto the [-99, 99] degree range.
    yaw_predicted = tf.nn.softmax(yaw)
    pitch_predicted = tf.nn.softmax(pitch)
    roll_predicted = tf.nn.softmax(roll)

    idx_tensor = tf.convert_to_tensor(list(range(num_bins)), dtype=tf.float32)

    yaw_predicted = tf.reduce_sum(yaw_predicted * idx_tensor, 1) * 3 - 99
    pitch_predicted = tf.reduce_sum(pitch_predicted * idx_tensor, 1) * 3 - 99
    roll_predicted = tf.reduce_sum(roll_predicted * idx_tensor, 1) * 3 - 99

    loss_reg_yaw = tf.reduce_mean(tf.square(yaw_predicted - label_yaw_cont))
    loss_reg_pitch = tf.reduce_mean(tf.square(pitch_predicted - label_pitch_cont))
    loss_reg_roll = tf.reduce_mean(tf.square(roll_predicted - label_roll_cont))

    # Total loss per angle: classification + alpha-weighted regression.
    loss_yaw = tf.add_n([loss_yaw_ce, args.alpha * loss_reg_yaw])
    loss_pitch = tf.add_n([loss_pitch_ce, args.alpha * loss_reg_pitch])
    loss_roll = tf.add_n([loss_roll_ce, args.alpha * loss_reg_roll])

    loss_all = loss_yaw + loss_pitch + loss_roll

    global_step = tf.get_variable("step", [],
                                  initializer=tf.constant_initializer(0.0),
                                  trainable=False)
    rate = tf.train.exponential_decay(0.00001, global_step, decay_steps=2000,
                                      decay_rate=0.97, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(
        loss_all, global_step=global_step)

    tf.summary.scalar('loss_yaw', loss_yaw)
    tf.summary.scalar('loss_pitch', loss_pitch)
    tf.summary.scalar('loss_roll', loss_roll)
    tf.summary.scalar('loss_all', loss_all)
    merged_summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(init)

        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        if not os.path.exists(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        ckpt = tf.train.latest_checkpoint(args.checkpoint_dir)
        if ckpt:
            tf.logging.info('restore the trained model')
            saver = tf.train.Saver(max_to_keep=5)
            saver.restore(sess, ckpt)
        else:
            # No checkpoint yet: initialize the backbone from the ImageNet
            # ResNet-50 weights, skipping the freshly-added 'Logits' head.
            tf.logging.info('load the pre-trained model')
            checkpoint_exclude_scopes = 'Logits'
            exclusions = []
            if checkpoint_exclude_scopes:
                exclusions = [scope.strip() for scope in
                              checkpoint_exclude_scopes.split(',')]

            variables_to_restore = []
            for var in slim.get_model_variables():
                print(var)
                # Bug fix: the original appended `var` inside the inner
                # loop's else branch, once per NON-matching exclusion —
                # duplicating variables whenever more than one scope is
                # excluded. for/else appends each variable at most once,
                # only if no exclusion matches.
                for exclusion in exclusions:
                    if var.op.name.startswith(exclusion):
                        break
                else:
                    variables_to_restore.append(var)

            saver_restore = tf.train.Saver(variables_to_restore)
            saver = tf.train.Saver(max_to_keep=5)
            saver_restore.restore(
                sess, os.path.join(args.pretrained_path, 'resnet_v1_50.ckpt'))

        train_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
        dataset = datasets.Pose_300W_LP(data_dir='D:/300W_LP',
                                        filename_path=args.filename_list,
                                        batch_size=args.batch_size,
                                        image_size=args.image_size)

        for epoch in range(args.num_epochs):
            for i in range(dataset.length // args.batch_size):
                batch_images, batch_labels, batch_cont_labels = dataset.get()
                train_dict = {images: batch_images,
                              labels: batch_labels,
                              is_training: True,
                              keep_prob: 0.5,
                              cont_labels: batch_cont_labels}
                _, loss, yaw_loss, pitch_loss, roll_loss, train_summary, step = sess.run(
                    [train_op, loss_all, loss_yaw, loss_pitch, loss_roll,
                     merged_summary_op, global_step],
                    feed_dict=train_dict)
                train_writer.add_summary(train_summary, step)

                if step % 100 == 0:
                    # Bug fix: the original embedded a long run of spaces
                    # in the message via a backslash line continuation
                    # inside the string literal; use implicit string
                    # concatenation instead.
                    tf.logging.info(
                        'the epoch %d: the loss of the step %d is: '
                        'total_loss:%f, loss_yaw:%f, loss_pitch:%f , '
                        'loss_roll:%f' %
                        (epoch, step, loss, yaw_loss, pitch_loss, roll_loss))

                if step % 500 == 0:
                    tf.logging.info('the epoch:%d, save the model for step %d'
                                    % (epoch, step))
                    # NOTE(review): step * epoch is 0 for the whole first
                    # epoch, so early checkpoints all overwrite 'model-0';
                    # passing plain `step` looks intended — confirm.
                    saver.save(sess,
                               os.path.join(args.checkpoint_dir, 'model'),
                               global_step=tf.cast(step * epoch, tf.int32))

        tf.logging.info('==================Train Finished================')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')

    transformations = transforms.Compose([
        transforms.Scale(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])  # rgb模式

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list,
                                             transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir,
                                                       args.filename_list,
                                                       transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
                                         transformations)
    elif args.dataset == 'AFLW2000_ds':
        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
                                            transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list,
                                     transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list,
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    else:
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))




def create_example(img, binned_pose, cont_labels):
    """Pack one image and its pose labels into a tf.train.Example.

    Args:
        img: PIL-style image exposing ``tobytes()``, ``width``, ``height``.
        binned_pose: list of integer bin indices for the pose angles.
        cont_labels: list of float continuous angle labels.

    Returns:
        A ``tf.train.Example`` ready to be serialized into a TFRecord.
    """
    feature_map = {
        'img_raw': tf_bytes_feature(img.tobytes()),
        'img_width': tf_int_feature(img.width),
        'img_height': tf_int_feature(img.height),
        'binned_pose': tf_int_feature(binned_pose),
        'cont_labels': tf_float_feature(cont_labels),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))



if __name__ == '__main__':

    # Convert the 300W_LP dataset into a single TFRecord file, resizing
    # every image to the 112x112 network input size.
    total_img_cnt = 0
    with tf.python_io.TFRecordWriter("tfrecord_dataset/train.tfrecords") as writer:
        testdataset = datasets.Pose_300W_LP()
        for img, binned_labels, cont_labels, imgpath in testdataset.generate():
            resized = img.resize((112, 112))
            example = create_example(resized,
                                     binned_labels.tolist(),
                                     cont_labels.tolist())
            writer.write(example.SerializeToString())
            total_img_cnt = total_img_cnt + 1
    print("total img %d" % total_img_cnt)
Beispiel #6
0
    if args.snapshot == 'from_scratch' or args.snapshot == '':
        print "Learning from scratch"
    else:
        print "Loading from snapshot"
        saved_state_dict = torch.load(args.snapshot)
        load_filtered_state_dict(model, saved_state_dict)

    print 'Loading data.'

    transformations = transforms.Compose([transforms.Scale(240),
    transforms.RandomCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'AFLW_aug':
        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations, bin_width_degrees)
    else:
         model_zoo.load_url(
             'https://download.pytorch.org/models/resnet50-19c8e357.pth'))
 else:
     saved_state_dict = torch.load(args.snapshot)
     model.load_state_dict(saved_state_dict)
 print('Loading data.')
 transformations = transforms.Compose([
     transforms.Scale(240),
     transforms.RandomCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
 ])
 transform = transforms.ToTensor()
 pose_dataset = datasets.Pose_300W_LP(
     '/home/leechanhyuk/Downloads/NEW_IMAGE',
     '/home/leechanhyuk/PycharmProjects/tensorflow1/file_name_list.txt',
     transform)
 train_loader = torch.utils.data.DataLoader(
     dataset=pose_dataset,
     batch_size=20,  # 원래는 arg_parse로 받는 수 였음(수정함)
     shuffle=True,
     num_workers=2)
 model.cuda(gpu)
 criterion = nn.CrossEntropyLoss().cuda(gpu)
 reg_criterion = nn.MSELoss().cuda(gpu)
 # Regression loss coefficient
 alpha = args.alpha
 softmax = nn.Softmax().cuda(gpu)
 idx_tensor = [idx for idx in range(67)]
 idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
 optimizer = torch.optim.Adam([{
    else:
        saved_state_dict = torch.load(args.snapshot)
        model.load_state_dict(saved_state_dict)

    print 'Loading data.'

    transformations = transforms.Compose([
        transforms.Resize(240),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir,
                                                       args.filename_list,
                                                       transformations)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list,
                                        transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
                                         transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list,
                                     transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list,
Beispiel #9
0
    def train(self):
        """Build the joint pose loss and run the training loop.

        Assumes the forward graph already exists on ``self``: the
        ``self.yaw``/``self.pitch``/``self.roll`` logits, the
        ``self.*_predicted`` continuous angles, plus ``self.images``,
        ``self.is_training``, ``self.keep_prob`` and an open
        ``self.sess``. Combines per-angle cross-entropy on binned labels
        with an MSE term on the continuous angles, trains with Adam, and
        periodically logs, summarizes, and checkpoints.
        """
        self.labels = tf.placeholder(dtype=tf.int32,
                                     shape=[None, 3],
                                     name='cls_label')
        self.cont_labels = tf.placeholder(dtype=tf.float32,
                                          shape=[None, 3],
                                          name='cont_labels')

        # Binned labels: one integer bin index per angle.
        self.label_yaw = self.labels[:, 0]
        self.label_pitch = self.labels[:, 1]
        self.label_roll = self.labels[:, 2]

        # Continuous labels: raw angles for the regression term.
        self.label_yaw_cont = self.cont_labels[:, 0]
        self.label_pitch_cont = self.cont_labels[:, 1]
        self.label_roll_cont = self.cont_labels[:, 2]

        # Cross-entropy loss on the binned targets.
        loss_yaw = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(self.label_yaw, tf.int64), logits=self.yaw)
        loss_pitch = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(self.label_pitch, tf.int64), logits=self.pitch)
        loss_roll = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(self.label_roll, tf.int64), logits=self.roll)

        self.loss_yaw_ce = tf.reduce_mean(loss_yaw)
        self.loss_pitch_ce = tf.reduce_mean(loss_pitch)
        self.loss_roll_ce = tf.reduce_mean(loss_roll)

        # MSE loss between the expected continuous angles (built elsewhere
        # on `self`) and the ground-truth continuous labels. An unused
        # local `idx_tensor` from the original was removed here.
        self.loss_reg_yaw = tf.reduce_mean(
            tf.square(self.yaw_predicted - self.label_yaw_cont))
        self.loss_reg_pitch = tf.reduce_mean(
            tf.square(self.pitch_predicted - self.label_pitch_cont))
        self.loss_reg_roll = tf.reduce_mean(
            tf.square(self.roll_predicted - self.label_roll_cont))

        # Total loss per angle: ce-weighted classification plus
        # alpha-weighted regression. (Bug fix: the original repeated these
        # three assignments twice in a row, adding duplicate graph nodes.)
        self.loss_yaw = self.args.ce * self.loss_yaw_ce + self.args.alpha * self.loss_reg_yaw
        self.loss_pitch = self.args.ce * self.loss_pitch_ce + self.args.alpha * self.loss_reg_pitch
        self.loss_roll = self.args.ce * self.loss_roll_ce + self.args.alpha * self.loss_reg_roll

        self.loss_all = self.loss_yaw + self.loss_pitch + self.loss_roll

        global_step = tf.get_variable("step", [],
                                      initializer=tf.constant_initializer(0.0),
                                      trainable=False)
        rate = tf.train.exponential_decay(0.00001,
                                          global_step,
                                          decay_steps=2000,
                                          decay_rate=0.97,
                                          staircase=True)
        train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(
            self.loss_all, global_step=global_step)

        tf.summary.scalar('loss_mse_yaw', self.loss_reg_yaw)
        tf.summary.scalar('loss_mse_pitch', self.loss_reg_pitch)
        tf.summary.scalar('loss_mse_roll', self.loss_reg_roll)
        tf.summary.scalar('loss_ce_yaw', self.loss_yaw_ce)
        tf.summary.scalar('loss_ce_pitch', self.loss_pitch_ce)
        tf.summary.scalar('loss_ce_roll', self.loss_roll_ce)
        tf.summary.scalar('loss_yaw', self.loss_yaw)
        tf.summary.scalar('loss_pitch', self.loss_pitch)
        tf.summary.scalar('loss_roll', self.loss_roll)
        tf.summary.scalar('loss_all', self.loss_all)
        merged_summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        self.sess.run(init)

        if not os.path.exists(self.args.log_dir):
            os.makedirs(self.args.log_dir)
        if not os.path.exists(self.args.checkpoint_dir):
            os.makedirs(self.args.checkpoint_dir)

        ckpt = tf.train.latest_checkpoint(self.args.checkpoint_dir)
        if ckpt:
            tf.logging.info('restore the trained model')
            saver = tf.train.Saver(max_to_keep=5)
            saver.restore(self.sess, ckpt)
        else:
            # No checkpoint yet: initialize the backbone from the ImageNet
            # ResNet-50 weights, skipping the freshly-added 'Logits' head.
            print('[ ] load resnet model ....')
            tf.logging.info('load the pre-trained model')
            checkpoint_exclude_scopes = 'Logits'
            exclusions = []
            if checkpoint_exclude_scopes:
                exclusions = [
                    scope.strip()
                    for scope in checkpoint_exclude_scopes.split(',')
                ]

            variables_to_restore = []
            for var in slim.get_model_variables():
                print(var)
                # Bug fix: the original appended `var` inside the inner
                # loop's else branch, once per NON-matching exclusion —
                # duplicating variables whenever more than one scope is
                # excluded. for/else appends each variable at most once,
                # only if no exclusion matches.
                for exclusion in exclusions:
                    if var.op.name.startswith(exclusion):
                        break
                else:
                    variables_to_restore.append(var)

            saver_restore = tf.train.Saver(variables_to_restore)
            saver = tf.train.Saver(max_to_keep=5)
            saver_restore.restore(
                self.sess,
                os.path.join(self.args.pretrained_path, 'resnet_v1_50.ckpt'))
            print('[*] finished loading resnet model ....')

        train_writer = tf.summary.FileWriter(self.args.log_dir,
                                             self.sess.graph)
        dataset = datasets.Pose_300W_LP(data_dir='./300W_LP',
                                        filename_path=self.args.filename_list,
                                        batch_size=self.args.batch_size,
                                        image_size=self.args.image_size)

        for epoch in range(self.args.num_epochs):
            for i in range(dataset.length // self.args.batch_size):
                batch_images, batch_labels, batch_cont_labels = dataset.get()
                train_dict = {
                    self.images: batch_images,
                    self.labels: batch_labels,
                    self.is_training: True,
                    self.keep_prob: 0.5,
                    self.cont_labels: batch_cont_labels
                }
                _, loss, yaw_loss, pitch_loss, roll_loss, train_summary, step = self.sess.run(
                    [
                        train_op, self.loss_all, self.loss_yaw,
                        self.loss_pitch, self.loss_roll, merged_summary_op,
                        global_step
                    ],
                    feed_dict=train_dict)

                train_writer.add_summary(train_summary, step)

                # The original had two identical `step % 100` guards;
                # merged into one (prediction print first, then losses,
                # preserving the original output order).
                if step % 100 == 0:
                    yaw, pitch, roll = self.sess.run(
                        [self.yaw_predicted, self.pitch_predicted,
                         self.roll_predicted],
                        feed_dict=train_dict)
                    print('GT: {}\nP :{},{},{}'.format(batch_cont_labels[0],
                                                       yaw[0], pitch[0],
                                                       roll[0]))
                    msg = (
                        'the epoch %d: the loss of the step %d is: total_loss:%f\n loss_yaw:%f\n loss_pitch:%f\n loss_roll:%f'
                        % (epoch, step, loss, yaw_loss, pitch_loss, roll_loss))
                    tf.logging.info(msg)
                    print(msg)

                if step % 500 == 0:
                    tf.logging.info(
                        'the epoch:%d, save the model for step %d' %
                        (epoch, step))
                    print('the epoch:%d, save the model for step %d' %
                          (epoch, step))
                    # NOTE(review): step * epoch is 0 for the whole first
                    # epoch, so early checkpoints all overwrite 'model-0';
                    # passing plain `step` looks intended — confirm.
                    saver.save(self.sess,
                               os.path.join(self.args.checkpoint_dir, 'model'),
                               global_step=tf.cast(step * epoch, tf.int32))

        tf.logging.info('==================Train Finished================')
        print('==================Train Finished================')
    if args.snapshot == '':
        print("lodel pretrain model")
        load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth'))
    else:
        saved_state_dict = torch.load(args.snapshot)
        model.load_state_dict(saved_state_dict)
    '''

    transformations = transforms.Compose([
        transforms.Resize(128),
        transforms.RandomCrop(112),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
    ])

    pose_dataset = datasets.Pose_300W_LP(transformations)

    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    reg_criterion = nn.MSELoss().cuda(gpu)
    # Regression loss coefficient
    alpha = args.alpha

    softmax = nn.Softmax(dim=1).cuda(gpu)
    idx_tensor = [idx for idx in range(67)]
    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
    print('Loading data.')
    transformations = transforms.Compose([
        transforms.Scale(240),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    if args.dataset == 'Pose_300W_LP':
        train_filename_list = os.path.join(args.filename_list,
                                           'train_filename_all.npy')
        val_filename_list = os.path.join(args.filename_list,
                                         'val_filename_all.npy')
        pose_dataset_train = datasets.Pose_300W_LP(args.data_dir, num_bins,
                                                   train_filename_list,
                                                   transformations, args.debug)
        pose_dataset_val = datasets.Pose_300W_LP(args.data_dir, num_bins,
                                                 val_filename_list,
                                                 transformations, args.debug)
    elif args.dataset == 'AFLW2000':
        test_filename_list = os.path.join(args.filename_list,
                                          'test_filename.npy')
        pose_dataset = datasets.AFLW2000(args.data_dir, num_bins,
                                         args.filename_list, transformations)
    else:
        print('Error: not a valid dataset name')
        sys.exit()

    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset_train,
                                               batch_size=batch_size,