Example #1
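# Assumed module header (the listing omits imports); repo-local helper
# locations are assumptions:
import math
import os
import pickle
import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import optim
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from torchvision import transforms as trans
# repo-local: Backbone, Arcface, MobileFaceNet, l2_norm, separate_bn_paras,
# get_train_loader, get_time, gen_plot, hflip_batch, evaluate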
class face_learner(object):
    def __init__(self, conf, inference=False, transfer=0, ext='final'):
        pprint.pprint(conf)
        self.conf = conf
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            tmp_idx = ext.rfind('_')  # find the last '_' to replace it by '/'
            self.ext = '/' + ext[:tmp_idx] + '/' + ext[tmp_idx + 1:]
            self.writer = SummaryWriter(str(conf.log_path) + self.ext)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            print('Arcface head generated')

            # split out the batch-norm parameters (note: the split is unused
            # below, since Adam is given all parameters jointly)
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            self.optimizer = optim.Adam(
                list(self.model.parameters()) + list(self.head.parameters()),
                conf.lr)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

            print('optimizers generated')
            self.save_freq = len(self.loader) // 5   # originally 100
            self.evaluate_every = len(self.loader)   # alt: // 5; originally 10
            self.save_every = len(self.loader)       # alt: // 2; originally 5
            # self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(self.loader.dataset.root.parent)
            # self.val_112, self.val_112_issame = get_val_pair(self.loader.dataset.root.parent, 'val_112')
        else:
            self.threshold = conf.threshold

        self.train_losses = []
        self.train_counter = []
        self.test_losses = []
        self.test_accuracy = []
        self.test_counter = []

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        flipped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        flipped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
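
    # Usage sketch (hypothetical names): with `val_112` as an array of aligned
    # face crops and `val_112_issame` as the boolean pair labels,
    #   acc, best_thresh, roc = learner.evaluate(conf, val_112, val_112_issame, tta=True)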

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in enumerate(
                self.loader):  #tqdm(enumerate(self.loader), total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            # Do the SGD step
            loss.backward()
            self.optimizer.step()

            # Update the lr for the next step
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
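
    # Usage sketch: log_lrs, losses = learner.find_lr(conf); a common heuristic
    # is to pick a learning rate roughly an order of magnitude below the point
    # where the smoothed loss bottoms out on the resulting plot.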

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels in self.loader:  # tqdm(iter(self.loader))
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.save_freq == 0 and self.step != 0:
                    self.train_losses.append(loss.item())
                    self.train_counter.append(self.step)

                self.step += 1

            self.save_loss()

        # self.save_state(conf, accuracy, to_save_folder=True, extra=self.ext, model_only=True)

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (horizontal flip only)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
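
    # Usage sketch: min_idx, min_dist = learner.infer(conf, faces, facebank_embs);
    # an index of -1 means no facebank embedding fell within self.threshold
    # (`facebank_embs` is a hypothetical [n, 512] tensor of known-face embeddings).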

    def binfer(self, conf, faces, target_embs, tta=False):
        '''
        Return raw distance scores against every facebank entry.
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (horizontal flip only)
        '''
        self.model.eval()
        self.plot_result()
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        # print(dist)
        return dist.detach().cpu().numpy()
        # minimum, min_idx = torch.min(dist, dim=1)
        # min_idx[minimum > self.threshold] = -1 # if no match, set idx to -1
        # return min_idx, minimum
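
    # Usage sketch: dist = learner.binfer(conf, faces, facebank_embs) returns an
    # [m, n] numpy array of squared L2 distances (lower means more similar;
    # `facebank_embs` is the hypothetical [n, 512] facebank tensor).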

    def save_loss(self):
        os.makedirs(self.conf.stored_result_dir, exist_ok=True)

        result = dict()
        result["train_losses"] = np.asarray(self.train_losses)
        result["train_counter"] = np.asarray(self.train_counter)
        result['test_accuracy'] = np.asarray(self.test_accuracy)
        result['test_losses'] = np.asarray(self.test_losses)
        result["test_counter"] = np.asarray(self.test_counter)

        with open(os.path.join(self.conf.stored_result_dir, "result_log.p"),
                  'wb') as fp:
            pickle.dump(result, fp)

    def plot_result(self):
        result_log_path = os.path.join(self.conf.stored_result_dir,
                                       "result_log.p")
        with open(result_log_path, 'rb') as f:
            result_dict = pickle.load(f)

        train_losses = result_dict['train_losses']
        train_counter = result_dict['train_counter']
        test_losses = result_dict['test_losses']
        test_counter = result_dict['test_counter']
        test_accuracy = result_dict['test_accuracy']

        fig1 = plt.figure(figsize=(12, 8))
        ax1 = fig1.add_subplot(111)
        ax1.plot(train_counter, train_losses, 'b', label='Train loss')
        ax1.legend()
        plt.savefig(os.path.join(self.conf.stored_result_dir,
                                 "train_loss.png"))
        """
Example #2
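# Assumed additions to the Example #1 header; the repo-local names (ResNet,
# RAdam, get_val_data, and the module-level `races` list) are assumptions:
import random
from copy import deepcopy

from torch import nn
from tqdm import tqdm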
class face_learner(object):
    def __init__(self, conf):
        print(conf)
        self.model = ResNet()
        self.model.cuda()
        if conf.initial:
            self.model.load_state_dict(torch.load("models/" + conf.model))
            print('Loaded {}'.format(conf.model))
        self.milestones = conf.milestones
        self.loader, self.class_num = get_train_loader(conf)
        self.total_class = 16520   # dataset-specific: total identities across races
        self.data_num = 285356     # dataset-specific: total training images
        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        self.paras_only_bn, self.paras_wo_bn = separate_bn_paras(self.model)

        if conf.meta:
            self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.total_class)
            self.head.cuda()
            if conf.initial:
                self.head.load_state_dict(torch.load("models/head_op.pth"))
                print('Load head_op.pth')
            self.optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.meta_optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.head.train()
        else:
            self.head = dict()
            self.optimizer = dict()
            for race in races:
                self.head[race] = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num[race])
                self.head[race].cuda()
                if conf.initial:
                    self.head[race].load_state_dict(torch.load("models/head_op_{}.pth".format(race)))
                    print('Load head_op_{}.pth'.format(race))
                self.optimizer[race] = RAdam([
                    {'params': self.paras_wo_bn + [self.head[race].kernel], 'weight_decay': 5e-4},
                    {'params': self.paras_only_bn}
                ], lr=conf.lr, betas=(0.5, 0.999))
                self.head[race].train()
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

        self.board_loss_every = min(len(self.loader[race]) for race in races) // 10
        self.evaluate_every = self.data_num // 5
        self.save_every = self.data_num // 2
        self.eval, self.eval_issame = get_val_data(conf)

    def save_state(self, conf, accuracy, extra=None, model_only=False, race='All'):
        save_path = 'models/'
        torch.save(
            self.model.state_dict(),
            save_path + 'model_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra, race))
        if not model_only:
            head = self.head if conf.meta else self.head[race]
            torch.save(
                head.state_dict(),
                save_path + 'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra, race))
            # saving the optimizer state here is currently disabled

    def load_state(self, conf, fixed_str, model_only=False):
        # note: fixed_str is unused here; checkpoint names come from conf
        save_path = 'models/'
        self.model.load_state_dict(torch.load(save_path + conf.model))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path + conf.head))
            self.optimizer.load_state_dict(torch.load(save_path + conf.optim))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        entry_num = carray.size()[0]
        embeddings = np.zeros([entry_num, conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= entry_num:
                batch = carray[idx:idx + conf.batch_size]
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(flipped.cuda())
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(batch.cuda()).cpu().detach().numpy()
                idx += conf.batch_size
            if idx < entry_num:
                batch = carray[idx:]
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(flipped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu().detach().numpy()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train_finetuning(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            # Milestone LR drops (lr /= 10 for each per-race optimizer) are
            # disabled during fine-tuning.
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer[race].zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head[race].parameters(), conf.max_grad_norm)
                self.optimizer[race].step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % len(self.loader[race]) == 0 and self.step != 0:
                    self.save_state(conf, 'None', race=race, model_only=True)

                self.step += 1

        self.save_state(conf, 'None', extra='final', race=race)
        torch.save(self.optimizer[race].state_dict(), 'models/optimizer_{}.pth'.format(race))
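
    # Usage sketch (hypothetical race key; must match the per-race loader/head
    # dicts built in __init__):
    #   learner.train_finetuning(conf, epochs=5, race='African')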

    def train_maml(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        loader_iter = dict()
        for race in races:
            loader_iter[race] = iter(self.loader[race])
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            for i in tqdm(range(self.data_num // conf.batch_size)):
                ra1, ra2 = random.sample(races, 2)
                try:
                    imgs1, labels1 = next(loader_iter[ra1])
                except StopIteration:
                    loader_iter[ra1] = iter(self.loader[ra1])
                    imgs1, labels1 = next(loader_iter[ra1])

                try:
                    imgs2, labels2 = next(loader_iter[ra2])
                except StopIteration:
                    loader_iter[ra2] = iter(self.loader[ra2])
                    imgs2, labels2 = next(loader_iter[ra2])

                # Save the original weights so the meta step can restore them
                weights_original_model = deepcopy(self.model.state_dict())
                weights_original_head = deepcopy(self.head.state_dict())

                # base learn
                imgs1 = imgs1.cuda()
                labels1 = labels1.cuda()
                self.optimizer.zero_grad()
                embeddings1 = self.model(imgs1)
                thetas1 = self.head(embeddings1, labels1)
                loss1 = conf.ce_loss(thetas1, labels1)
                loss1.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.optimizer.step()

                # meta learn: forward with the adapted weights, then restore the
                # originals before applying the meta gradient to them
                imgs2 = imgs2.cuda()
                labels2 = labels2.cuda()
                embeddings2 = self.model(imgs2)
                thetas2 = self.head(embeddings2, labels2)
                self.model.load_state_dict(weights_original_model)
                self.head.load_state_dict(weights_original_head)
                self.meta_optimizer.zero_grad()
                loss2 = conf.ce_loss(thetas2, labels2)
                loss2.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.meta_optimizer.step()

                running_loss += loss2.item()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for race in races:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.eval[race], self.eval_issame[race])
                        self.board_val(race, accuracy, best_threshold, roc_curve_tensor)
                    self.model.train()

                if self.step % (self.data_num // conf.batch_size // 2) == 0 and self.step != 0:
                    self.save_state(conf, e)

                self.step += 1

        self.save_state(conf, epochs, extra='final')

    def train_meta_head(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head.parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                # schedule_lr() only adjusts self.optimizer / self.meta_optimizer,
                # so drop the local head optimizer's lr directly
                for params in optimizer.param_groups:
                    params['lr'] /= 10
            for race in races:
                for imgs, labels in tqdm(iter(self.loader[race])):
                    imgs = imgs.cuda()
                    labels = labels.cuda()
                    optimizer.zero_grad()
                    embeddings = self.model(imgs)
                    thetas = self.head(embeddings, labels)
                    loss = conf.ce_loss(thetas, labels)
                    loss.backward()
                    running_loss += loss.item()
                    optimizer.step()

                    if self.step % self.board_loss_every == 0 and self.step != 0:
                        loss_board = running_loss / self.board_loss_every
                        self.writer.add_scalar('train_loss', loss_board, self.step)
                        running_loss = 0.

                    self.step += 1

            torch.save(self.head.state_dict(), 'models/head_{}_meta_{}.pth'.format(get_time(), e))

    def train_race_head(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head[race].parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                # schedule_lr() only adjusts self.optimizer / self.meta_optimizer,
                # so drop the local head optimizer's lr directly
                for params in optimizer.param_groups:
                    params['lr'] /= 10
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                self.step += 1

        torch.save(self.head[race].state_dict(), 'models/head_{}_{}_{}.pth'.format(get_time(), race, epochs))

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        for params in self.meta_optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer, self.meta_optimizer)
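
A minimal driver sketch for the meta-training variant (hypothetical: assumes the
per-race loader/head layout used above and the module-level `races` list):

from types import SimpleNamespace
import torch

conf = SimpleNamespace()
conf.meta = True
conf.initial = False
conf.embedding_size = 512
conf.lr = 1e-3
conf.momentum = 0.9
conf.max_grad_norm = 5.0
conf.batch_size = 64
conf.milestones = [6, 11, 16]
conf.log_path = 'logs/meta'        # hypothetical path
conf.model = 'model_ir_se101.pth'  # checkpoint name, used when conf.initial=True
conf.ce_loss = torch.nn.CrossEntropyLoss()

learner = face_learner(conf)
learner.train_maml(conf, epochs=20)  # alternates base and meta steps across races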