Example #1
# imports inferred from the code below; Encoder, Classifier and Domain_classifier
# are the project's model classes (their import path is not shown here)
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

import Dataset  # local module providing Dataset.Dataset


def main():
    # parameters
    learning_rate = 0.01
    num_epochs = 50
    batch_size = 250
    feature_size = 2048
    test_index = 2
    number_of_domain = 4  # number of domains (inferred from the 4 domain labels below)

    # create the save/log directories
    print("Create the directories")
    if not os.path.exists("./save"):
        os.makedirs("./save")
    if not os.path.exists("./logfile"):
        os.makedirs("./logfile")
    if not os.path.exists("./logfile/MTL"):
        os.makedirs("./logfile/MTL")

    # load my Dataset
    type = ["infograph", "quickdraw", "sketch", "real", "test"]
    print("training set : %s ,%s, %s" % (type[0], type[1], type[3]))
    print("testing set : %s" % (type[2]))
    inf_train_dataset = Dataset.Dataset(mode="train", type=type[0])
    inf_train_loader = DataLoader(inf_train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=1)
    qdr_train_dataset = Dataset.Dataset(mode="train", type=type[1])
    qdr_train_loader = DataLoader(qdr_train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=1)
    skt_train_dataset = Dataset.Dataset(mode="train", type=type[2])
    skt_train_loader = DataLoader(skt_train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=1)
    rel_train_dataset = Dataset.Dataset(mode="train", type=type[3])
    rel_train_loader = DataLoader(rel_train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=1)

    test_dataset = Dataset.Dataset(mode="test", type=type[0])
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=1)

    print('the source dataset has %d samples.' % (len(inf_train_dataset)))
    print('the target dataset has %d samples.' % (len(test_dataset)))
    print('the batch_size is %d' % (batch_size))

    # Pre-train models
    encoder = Encoder()
    classifier = Classifier(feature_size)
    domain_classifier = Domain_classifier(feature_size, number_of_domain)

    # GPU enable
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Device used:', device)
    # .to(device) already handles the CPU case, so no extra CUDA check is needed
    encoder = encoder.to(device)
    domain_classifier = domain_classifier.to(device)
    classifier = classifier.to(device)

    # setup optimizer
    optimizer_encoder = optim.Adam(encoder.parameters(),
                                   weight_decay=1e-4,
                                   lr=learning_rate)
    optimizer_domain_classifier = optim.Adam(domain_classifier.parameters(),
                                             weight_decay=1e-4,
                                             lr=learning_rate)
    optimizer_classifier = optim.Adam(classifier.parameters(),
                                      weight_decay=1e-4,
                                      lr=learning_rate)

    print("Starting training...")

    for epoch in range(num_epochs):
        print("Epoch:", epoch + 1)

        train_loader = [
            inf_train_loader, qdr_train_loader, skt_train_loader,
            rel_train_loader
        ]
        domain_labels = torch.LongTensor([[0 for i in range(batch_size)],
                                          [1 for i in range(batch_size)],
                                          [2 for i in range(batch_size)],
                                          [3 for i in range(batch_size)]])

        mtl_criterion = nn.CrossEntropyLoss()
        moe_criterion = nn.CrossEntropyLoss()

        encoder.train()
        domain_classifier.train()
        classifier.train()

        epoch_D_loss = 0.0
        epoch_C_loss = 0.0
        sum_trg_acc = 0.0
        sum_label_acc = 0.0
        sum_test_acc = 0.0

        for index, (inf, qdr, skt, rel, test) in enumerate(
                zip(train_loader[0], train_loader[1], train_loader[2],
                    train_loader[3], test_loader)):

            optimizer_encoder.zero_grad()
            optimizer_classifier.zero_grad()
            optimizer_domain_classifier.zero_grad()

            # calculate lambda_ for the gradient reversal schedule
            p = (index + len(train_loader[0]) * epoch) / (
                len(train_loader[0]) * num_epochs)
            lambda_ = 2.0 / (1. + np.exp(-10 * p)) - 1.0

            s1_imgs, s1_labels = inf
            s2_imgs, s2_labels = qdr
            s3_imgs, s3_labels = rel
            t1_imgs, _ = skt
            s1_imgs = Variable(s1_imgs).to(device)
            s1_labels = Variable(s1_labels).to(device)
            s2_imgs = Variable(s2_imgs).to(device)
            s2_labels = Variable(s2_labels).to(device)
            s3_imgs = Variable(s3_imgs).to(device)
            s3_labels = Variable(s3_labels).to(device)
            t1_imgs = Variable(t1_imgs).to(device)

            s1_feature = encoder(s1_imgs)
            #t1_feature = encoder(t1_imgs)

            # Testing
            test_imgs, test_labels = test
            test_imgs = Variable(test_imgs).to(device)
            test_labels = Variable(test_labels).to(device)
            with torch.no_grad():
                test_feature = encoder(test_imgs)
                test_output = classifier(test_feature)
            test_preds = test_output.argmax(1).cpu()
            test_acc = np.mean((test_preds == test_labels.cpu()).numpy())

            # Classifier network
            s1_output = classifier(s1_feature)

            s1_preds = s1_output.argmax(1).cpu()

            s1_acc = np.mean((s1_preds == s1_labels.cpu()).numpy())
            s1_c_loss = mtl_criterion(s1_output, s1_labels)
            C_loss = s1_c_loss

            # Domain_classifier network with source domain
            #domain_labels = Variable(domain_labels).to(device)
            #s1_domain_output = domain_classifier(s1_feature,lambda_)

            #s1_domain_preds = s1_domain_output.argmax(1).cpu()
            #if index == 10:
            #    print(s1_domain_preds)
            #s1_domain_acc = np.mean((s1_domain_preds == 0).numpy())
            #print(s1_domain_output.shape)
            #print(s1_domain_output[0])
            #s1_d_loss = moe_criterion(s1_domain_output,domain_labels[0])
            #D_loss_src = s1_d_loss
            #print(D_loss_src.item())

            # Domain_classifier network with target domain
            #t1_domain_output = domain_classifier(t1_feature,lambda_)
            #t1_domain_preds = t1_domain_output.argmax(1).cpu()
            #t1_domain_acc = np.mean((t1_domain_preds == 3).numpy())
            #t1_d_loss = moe_criterion(t1_domain_output,domain_labels[3])

            #D_loss = D_loss_src + t1_d_loss
            loss = C_loss
            D_loss = 0
            #epoch_D_loss += D_loss.item()
            epoch_C_loss += C_loss.item()
            #sum_trg_acc += t1_domain_acc
            #D_src_acc = (s1_domain_acc + s2_domain_acc + s3_domain_acc)/3.

            loss.backward()
            optimizer_encoder.step()
            optimizer_classifier.step()
            optimizer_domain_classifier.step()
            if (index + 1) % 10 == 0:
                print(
                    'Iter [%d/%d] loss %.4f, D_loss %.4f, Acc %.4f, Test Acc: %.4f'
                    % (index + 1, len(train_loader[0]), loss.item(), D_loss,
                       s1_acc, test_acc))

        test_acc = 0.
        test_loss = 0.
        encoder.eval()
        domain_classifier.eval()
        classifier.eval()

        for index, (imgs, labels) in enumerate(test_loader):
            output_list = []
            loss_mtl = []
            imgs = Variable(imgs).to(device)
            labels = Variable(labels).to(device)
            with torch.no_grad():
                hidden = encoder(imgs)
                output = classifier(hidden)
            preds = output.argmax(1).cpu()
            s1_acc = np.mean((preds == labels.cpu()).numpy())
            """
            for sthi in classifiers:
                output = sthi(hidden)
                output_list.append(output.cpu())
                loss = mtl_criterion(output, labels)
                loss_mtl.append(loss)


            output = torch.FloatTensor(np.array(output_list).sum(0))
            preds = output.argmax(1).cpu()
            s1_preds = output_list[0].argmax(1).cpu()
            s2_preds = output_list[1].argmax(1).cpu()
            s3_preds = output_list[2].argmax(1).cpu()
            acc = np.mean((preds == labels.cpu()).numpy())
            s1_acc = np.mean((s1_preds == labels.cpu()).numpy())
            s2_acc = np.mean((s2_preds == labels.cpu()).numpy())
            s3_acc = np.mean((s3_preds == labels.cpu()).numpy())
            if index == 0:
                print(acc)
            loss_mtl = sum(loss_mtl)
            loss = loss_mtl
            test_acc += acc
            test_loss += loss.item()
            """
        #print('Testing: loss %.4f,Acc %.4f ,s1 %.4f,s2 %.4f,s3 %.4f' %(test_loss/len(test_loader),test_acc/len(test_loader),s1_acc,s2_acc,s3_acc))

    return 0
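
# Note: the lambda_ computed above is the usual gradient-reversal ramp-up,
# lambda = 2 / (1 + exp(-10 * p)) - 1, where p is the fraction of training
# completed. A minimal, self-contained sketch of that schedule (illustrative
# only; the function name is not part of the original example):
import numpy as np

def grl_lambda(step, steps_per_epoch, epoch, num_epochs):
    # p runs from 0 at the first iteration to 1 at the last one
    p = (step + steps_per_epoch * epoch) / (steps_per_epoch * num_epochs)
    # smoothly ramps the domain-adversarial weight from 0 up to ~1
    return 2.0 / (1.0 + np.exp(-10 * p)) - 1.0

# e.g. near the start the weight is close to 0, near the end close to 1
print(grl_lambda(0, 100, 0, 50))    # 0.0
print(grl_lambda(99, 100, 49, 50))  # ~1.0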
Example #2
# imports inferred from the code below; Encoder, Decoder, Resnet, FERDataset,
# LabelEnum, get_data and the parsed `args` object are project-specific and
# assumed to be defined/imported elsewhere in the original source
import math
import os
import os.path as osp
import random

import joblib
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from imageio import imread  # assumed source of imread
from sklearn.svm import SVC
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.utils import save_image
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup  # assumed source


def main():
    
    #create tensorboard summary writer
    writer = SummaryWriter(args.experiment_id)
    #[TODO] may need to resize input image
    cudnn.enabled = True
    #create model: Encoder
    model_encoder = Encoder()
    model_encoder.train()
    model_encoder.cuda(args.gpu)
    optimizer_encoder = optim.Adam(model_encoder.parameters(), lr=args.learning_rate, betas=(0.95, 0.99))
    optimizer_encoder.zero_grad()

    #create model: Decoder
    model_decoder = Decoder()
    model_decoder.train()
    model_decoder.cuda(args.gpu)
    optimizer_decoder = optim.Adam(model_decoder.parameters(), lr=args.learning_rate, betas=(0.95, 0.99))
    optimizer_decoder.zero_grad()
    
    l2loss = nn.MSELoss()
    
    #load data
    for i in range(1, 360002, 30000):
        train_data, valid_data = get_data(i)
        for e in range(1, args.epoch + 1):
            train_loss_value = 0
            validation_loss_value = 0
            for j in range(0, int(args.train_size/4), args.batch_size):
                optimizer_decoder.zero_grad()
                optimizer_encoder.zero_grad()
                image = Variable(torch.tensor(train_data[j: j + args.batch_size, :, :])).cuda(args.gpu)
                latent = model_encoder(image)
                img_recon = model_decoder(latent)
                img_recon = F.interpolate(img_recon, size=image.shape[2:], mode='bilinear', align_corners=True) 
                loss = l2loss(img_recon, image)
                train_loss_value += loss.data.cpu().numpy() / args.batch_size
                loss.backward()
                optimizer_decoder.step()
                optimizer_encoder.step()
            print("data load: {:8d}".format(i))
            print("epoch: {:8d}".format(e))
            print("train_loss: {:08.6f}".format(train_loss_value / (args.train_size / args.batch_size)))
            for j in range(0,int(args.validation_size/4), args.batch_size):
                model_encoder.eval()
                model_decoder.eval() 
                image = Variable(torch.tensor(valid_data[j: j + args.batch_size, :, :])).cuda(args.gpu)
                latent = model_encoder(image)
                img_recon = model_decoder(latent)
                img_1 = img_recon[0][0]
                img = image[0][0]
                img_recon = F.interpolate(img_recon, size=image.shape[2:], mode='bilinear', align_corners=True) 
                save_image(img_1, args.image_dir + '/fake' + str(i) + "_" + str(j) + ".png")
                save_image(img, args.image_dir + '/real' + str(i) + "_" + str(j) + ".png")
                loss = l2loss(img_recon, image)  # reconstruction loss against the validation batch
                validation_loss_value += loss.data.cpu().numpy() / args.batch_size
            model_encoder.train()
            model_decoder.train()
            print("train_loss: {:08.6f}".format(validation_loss_value / (args.validation_size / args.batch_size)))
        torch.save({'encoder_state_dict': model_encoder.state_dict()}, osp.join(args.checkpoint_dir, 'AE_encoder.pth'))
        torch.save({'decoder_state_dict': model_decoder.state_dict()}, osp.join(args.checkpoint_dir, 'AE_decoder.pth'))
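
# Note: the loop above checkpoints the autoencoder as dictionaries keyed by
# 'encoder_state_dict' and 'decoder_state_dict'. A minimal sketch of how such a
# checkpoint could be restored later (illustrative only; it assumes the same
# Encoder/Decoder classes and checkpoint directory as above):
def load_autoencoder(checkpoint_dir, device='cpu'):
    encoder, decoder = Encoder(), Decoder()
    enc_ckpt = torch.load(osp.join(checkpoint_dir, 'AE_encoder.pth'), map_location=device)
    dec_ckpt = torch.load(osp.join(checkpoint_dir, 'AE_decoder.pth'), map_location=device)
    encoder.load_state_dict(enc_ckpt['encoder_state_dict'])
    decoder.load_state_dict(dec_ckpt['decoder_state_dict'])
    encoder.eval()
    decoder.eval()
    return encoder, decoder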
class Instructor:
    def __init__(self, model_name: str, args):
        self.model_name = model_name
        self.args = args
        self.encoder = Encoder(self.args.add_noise).to(self.args.device)
        self.decoder = Decoder(self.args.upsample_mode).to(self.args.device)
        self.pretrainDataset = None
        self.pretrainDataloader = None
        self.pretrainOptimizer = None
        self.pretrainScheduler = None
        self.RHO_tensor = None
        self.pretrain_batch_cnt = 0
        self.writer = None
        self.svmDataset = None
        self.svmDataloader = None
        self.testDataset = None
        self.testDataloader = None
        self.svm = SVC(C=self.args.svm_c,
                       kernel=self.args.svm_ker,
                       verbose=True,
                       max_iter=self.args.svm_max_iter)
        self.resnet = Resnet(use_pretrained=True,
                             num_classes=self.args.classes,
                             resnet_depth=self.args.resnet_depth,
                             dropout=self.args.resnet_dropout).to(
                                 self.args.device)
        self.resnetOptimizer = None
        self.resnetScheduler = None
        self.resnetLossFn = None

    def _load_data_by_label(self, label: str) -> list:
        ret = []
        LABEL_PATH = os.path.join(self.args.TRAIN_PATH, label)
        for dir_path, _, file_list in os.walk(LABEL_PATH, topdown=False):
            for file_name in file_list:
                file_path = os.path.join(dir_path, file_name)
                img_np = imread(file_path)
                img = img_np.copy()
                img = img.tolist()
                ret.append(img)
        return ret

    def _load_all_data(self):
        all_data = []
        all_labels = []
        for label_id in range(0, self.args.classes):
            expression = LabelEnum(label_id)
            sub_data = self._load_data_by_label(expression.name)
            sub_labels = [label_id] * len(sub_data)
            all_data.extend(sub_data)
            all_labels.extend(sub_labels)
        return all_data, all_labels

    def _load_test_data(self):
        file_map = pd.read_csv(
            os.path.join(self.args.RAW_PATH, 'submission.csv'))
        test_data = []
        img_names = []
        for file_name in file_map['file_name']:
            file_path = os.path.join(self.args.TEST_PATH, file_name)
            img_np = imread(file_path)
            img = img_np.copy()
            img = img.tolist()
            test_data.append(img)
            img_names.append(file_name)
        return test_data, img_names

    def trainAutoEncoder(self):
        self.writer = SummaryWriter(
            os.path.join(self.args.LOG_PATH, self.model_name))
        all_data, all_labels = self._load_all_data()
        self.pretrainDataset = FERDataset(all_data,
                                          labels=all_labels,
                                          args=self.args)
        self.pretrainDataloader = DataLoader(dataset=self.pretrainDataset,
                                             batch_size=self.args.batch_size,
                                             shuffle=True,
                                             num_workers=self.args.num_workers)
        self.pretrainOptimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': self.args.pretrain_lr},
            {'params': self.decoder.parameters(), 'lr': self.args.pretrain_lr},
        ])
        tot_steps = math.ceil(
            len(self.pretrainDataloader) /
            self.args.cumul_batch) * self.args.epochs
        self.pretrainScheduler = get_linear_schedule_with_warmup(
            self.pretrainOptimizer,
            num_warmup_steps=0,
            num_training_steps=tot_steps)
        self.RHO_tensor = torch.tensor(
            [self.args.rho for _ in range(self.args.embed_dim)],
            dtype=torch.float).unsqueeze(0).to(self.args.device)
        epochs = self.args.epochs
        for epoch in range(1, epochs + 1):
            print()
            print(
                "================ AutoEncoder Training Epoch {:}/{:} ================"
                .format(epoch, epochs))
            print(" ---- Start training ------>")
            self.epochTrainAutoEncoder(epoch)
            print()
        self.writer.close()

    def epochTrainAutoEncoder(self, epoch):
        self.encoder.train()
        self.decoder.train()

        cumul_loss = 0
        cumul_steps = 0
        cumul_samples = 0

        self.pretrainOptimizer.zero_grad()
        cumulative_batch = 0

        for idx, (images, labels) in enumerate(tqdm(self.pretrainDataloader)):
            batch_size = images.shape[0]
            images, labels = images.to(self.args.device), labels.to(
                self.args.device)

            embeds = self.encoder(images)
            outputs = self.decoder(embeds)

            loss = torch.nn.functional.mse_loss(outputs, images)
            if self.args.use_sparse:
                rho_hat = torch.mean(embeds, dim=0, keepdim=True)
                sparse_penalty = self.args.regulizer_weight * torch.nn.functional.kl_div(
                    input=torch.nn.functional.log_softmax(rho_hat, dim=-1),
                    target=torch.nn.functional.softmax(self.RHO_tensor,
                                                       dim=-1))
                loss = loss + sparse_penalty

            loss_each = loss / self.args.cumul_batch
            loss_each.backward()

            cumulative_batch += 1
            cumul_steps += 1
            cumul_loss += loss.detach().cpu().item() * batch_size
            cumul_samples += batch_size

            if cumulative_batch >= self.args.cumul_batch:
                torch.nn.utils.clip_grad_norm_(self.encoder.parameters(),
                                               max_norm=self.args.max_norm)
                torch.nn.utils.clip_grad_norm_(self.decoder.parameters(),
                                               max_norm=self.args.max_norm)
                self.pretrainOptimizer.step()
                self.pretrainScheduler.step()
                self.pretrainOptimizer.zero_grad()
                cumulative_batch = 0

            if cumul_steps >= self.args.disp_period or idx + 1 == len(
                    self.pretrainDataloader):
                print(" -> cumul_steps={:} loss={:}".format(
                    cumul_steps, cumul_loss / cumul_samples))
                self.pretrain_batch_cnt += 1
                self.writer.add_scalar('batch-loss',
                                       cumul_loss / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('encoder_lr',
                                       self.pretrainOptimizer.state_dict()
                                       ['param_groups'][0]['lr'],
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('decoder_lr',
                                       self.pretrainOptimizer.state_dict()
                                       ['param_groups'][1]['lr'],
                                       global_step=self.pretrain_batch_cnt)
                cumul_steps = 0
                cumul_loss = 0
                cumul_samples = 0

        self.saveAutoEncoder(epoch)

    def saveAutoEncoder(self, epoch):
        encoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Encoder" + "--EPOCH-{:}".format(epoch))
        decoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Decoder" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Saving AutoEncoder {:} ......".format(self.model_name))
        torch.save(self.encoder.state_dict(), encoderPath)
        torch.save(self.decoder.state_dict(), decoderPath)
        print("  -> Successfully saved AutoEncoder.")
        print("-----------------------------------------------")

    def generateAutoEncoderTestResultSamples(self, sample_cnt):
        self.encoder.eval()
        self.decoder.eval()
        print('  -> Generating samples with AutoEncoder ...')
        save_path = os.path.join(self.args.SAMPLE_PATH, self.model_name)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        with torch.no_grad():
            for dir_path, _, file_list in os.walk(self.args.TEST_PATH,
                                                  topdown=False):
                sample_file_list = random.choices(file_list, k=sample_cnt)
                for file_name in sample_file_list:
                    file_path = os.path.join(dir_path, file_name)
                    img_np = imread(file_path)
                    img = img_np.copy()
                    img = ToTensor()(img)
                    img = img.reshape(1, 1, 48, 48)
                    img = img.to(self.args.device)
                    embed = self.encoder(img)
                    out = self.decoder(embed).cpu()
                    out = out.reshape(1, 48, 48)
                    out_img = ToPILImage()(out)
                    out_img.save(os.path.join(save_path, file_name))
        print('  -> Done sampling from AutoEncoder with test pictures.')

    def loadAutoEncoder(self, epoch):
        encoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Encoder" + "--EPOCH-{:}".format(epoch))
        decoderPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Decoder" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Loading AutoEncoder {:} ......".format(self.model_name))
        self.encoder.load_state_dict(torch.load(encoderPath))
        self.decoder.load_state_dict(torch.load(decoderPath))
        print("  -> Successfully loaded AutoEncoder.")
        print("-----------------------------------------------")

    def generateExtractedFeatures(
            self, data: torch.FloatTensor) -> torch.FloatTensor:
        """
        :param data: (batch, channel, l, w)
        :return: embed: (batch, embed_dim)
        """
        with torch.no_grad():
            data = data.to(self.args.device)
            embed = self.encoder(data)
            embed = embed.detach().cpu()
            return embed

    def trainSVM(self, load: bool):
        svm_path = os.path.join(self.args.CKPT_PATH, self.model_name + '--svm')
        self.loadAutoEncoder(self.args.epochs)
        self.encoder.eval()
        self.decoder.eval()
        if load:
            print('  -> Loaded from SVM trained model.')
            self.svm = joblib.load(svm_path)
            return
        print()
        print("================ SVM Training Starting ================")
        all_data, all_labels = self._load_all_data()
        all_length = len(all_data)
        self.svmDataset = FERDataset(all_data,
                                     labels=all_labels,
                                     use_da=False,
                                     args=self.args)
        self.svmDataloader = DataLoader(dataset=self.svmDataset,
                                        batch_size=self.args.batch_size,
                                        shuffle=False,
                                        num_workers=self.args.num_workers)
        print("  -> Converting to extracted features ...")
        cnt = 0
        all_embeds = []
        all_labels = []
        for images, labels in self.svmDataloader:
            cnt += 1
            embeds = self.generateExtractedFeatures(images)
            all_embeds.extend(embeds.tolist())
            all_labels.extend(labels.reshape(-1).tolist())
        print('  -> Start SVM fit ...')
        self.svm.fit(X=all_embeds, y=all_labels)
        # self.svm.fit(X=all_embeds[0:3], y=[0, 1, 2])
        joblib.dump(self.svm, svm_path)
        print("  -> Done training for SVM.")

    def genTestResult(self, from_svm=True):
        print()
        print('-------------------------------------------------------')
        print('  -> Generating test result for {:} ...'.format(
            'SVM' if from_svm else 'Resnet'))
        test_data, img_names = self._load_test_data()
        test_length = len(test_data)
        self.testDataset = FERDataset(test_data,
                                      filenames=img_names,
                                      use_da=False,
                                      args=self.args)
        self.testDataloader = DataLoader(dataset=self.testDataset,
                                         batch_size=self.args.batch_size,
                                         shuffle=False,
                                         num_workers=self.args.num_workers)
        str_preds = []
        for images, filenames in self.testDataloader:
            if from_svm:
                embeds = self.generateExtractedFeatures(images)
                preds = self.svm.predict(X=embeds)
            else:
                self.resnet.eval()
                outs = self.resnet(
                    images.repeat(1, 3, 1, 1).to(self.args.device))
                preds = outs.max(-1)[1].cpu().tolist()
            str_preds.extend([LabelEnum(pred).name for pred in preds])
        # generate submission
        assert len(str_preds) == len(img_names)
        submission = pd.DataFrame({'file_name': img_names, 'class': str_preds})
        submission.to_csv(os.path.join(self.args.DATA_PATH, 'submission.csv'),
                          index=False,
                          index_label=False)
        print('  -> Done generation of submission.csv with model {:}'.format(
            self.model_name))

    def epochTrainResnet(self, epoch):
        self.resnet.train()

        cumul_loss = 0
        cumul_acc = 0
        cumul_steps = 0
        cumul_samples = 0
        cumulative_batch = 0

        self.resnetOptimizer.zero_grad()

        for idx, (images, labels) in enumerate(tqdm(self.pretrainDataloader)):
            batch_size = images.shape[0]
            images, labels = images.to(self.args.device), labels.to(
                self.args.device)
            images += torch.randn(images.shape).to(
                images.device) * self.args.add_noise
            images = images.repeat(1, 3, 1, 1)

            outs = self.resnet(images)
            preds = outs.max(-1)[1].unsqueeze(dim=1)
            cur_acc = (preds == labels).type(torch.int).sum().item()

            loss = self.resnetLossFn(outs, labels.squeeze(dim=1))

            loss_each = loss / self.args.cumul_batch
            loss_each.backward()

            cumulative_batch += 1
            cumul_steps += 1
            cumul_loss += loss.detach().cpu().item() * batch_size
            cumul_acc += cur_acc
            cumul_samples += batch_size

            if cumulative_batch >= self.args.cumul_batch:
                torch.nn.utils.clip_grad_norm_(self.resnet.parameters(),
                                               max_norm=self.args.max_norm)
                self.resnetOptimizer.step()
                self.resnetScheduler.step()
                self.resnetOptimizer.zero_grad()
                cumulative_batch = 0

            if cumul_steps >= self.args.disp_period or idx + 1 == len(
                    self.pretrainDataloader):
                print(" -> cumul_steps={:} loss={:} acc={:}".format(
                    cumul_steps, cumul_loss / cumul_samples,
                    cumul_acc / cumul_samples))
                self.pretrain_batch_cnt += 1
                self.writer.add_scalar('batch-loss',
                                       cumul_loss / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar('batch-acc',
                                       cumul_acc / cumul_samples,
                                       global_step=self.pretrain_batch_cnt)
                self.writer.add_scalar(
                    'resnet_lr',
                    self.resnetOptimizer.state_dict()['param_groups'][0]['lr'],
                    global_step=self.pretrain_batch_cnt)
                cumul_steps = 0
                cumul_loss = 0
                cumul_acc = 0
                cumul_samples = 0

        if epoch % 10 == 0:
            self.saveResnet(epoch)

    def saveResnet(self, epoch):
        resnetPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Resnet" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Saving Resnet {:} ......".format(self.model_name))
        torch.save(self.resnet.state_dict(), resnetPath)
        print("  -> Successfully saved Resnet.")
        print("-----------------------------------------------")

    def loadResnet(self, epoch):
        resnetPath = os.path.join(
            self.args.CKPT_PATH,
            self.model_name + "--Resnet" + "--EPOCH-{:}".format(epoch))
        print("-----------------------------------------------")
        print("  -> Loading Resnet {:} ......".format(self.model_name))
        self.resnet.load_state_dict(torch.load(resnetPath))
        print("  -> Successfully loaded Resnet.")
        print("-----------------------------------------------")

    def trainResnet(self):
        self.writer = SummaryWriter(
            os.path.join(self.args.LOG_PATH, self.model_name))
        all_data, all_labels = self._load_all_data()
        self.pretrainDataset = FERDataset(all_data,
                                          labels=all_labels,
                                          args=self.args)
        self.pretrainDataloader = DataLoader(dataset=self.pretrainDataset,
                                             batch_size=self.args.batch_size,
                                             shuffle=True,
                                             num_workers=self.args.num_workers)
        self.resnetOptimizer = self.getResnetOptimizer()
        tot_steps = math.ceil(
            len(self.pretrainDataloader) /
            self.args.cumul_batch) * self.args.epochs
        self.resnetScheduler = get_linear_schedule_with_warmup(
            self.resnetOptimizer,
            num_warmup_steps=tot_steps * self.args.warmup_rate,
            num_training_steps=tot_steps)
        self.resnetLossFn = torch.nn.CrossEntropyLoss(
            weight=torch.tensor([
                9.40661861,
                1.00104606,
                0.56843877,
                0.84912748,
                1.02660468,
                1.29337298,
                0.82603942,
            ],
                                dtype=torch.float,
                                device=self.args.device))
        epochs = self.args.epochs
        for epoch in range(1, epochs + 1):
            print()
            print(
                "================ Resnet Training Epoch {:}/{:} ================"
                .format(epoch, epochs))
            print(" ---- Start training ------>")
            self.epochTrainResnet(epoch)
            print()
        self.writer.close()

    def getResnetOptimizer(self):
        if self.args.resnet_optim == 'SGD':
            return torch.optim.SGD([{
                'params': self.resnet.baseParameters(),
                'lr': self.args.resnet_base_lr,
                'weight_decay': self.args.weight_decay,
                'momentum': self.args.resnet_momentum
            }, {
                'params': self.resnet.finetuneParameters(),
                'lr': self.args.resnet_ft_lr,
                'weight_decay': self.args.weight_decay,
                'momentum': self.args.resnet_momentum
            }],
                                   lr=self.args.resnet_base_lr)
        elif self.args.resnet_optim == 'Adam':
            return torch.optim.Adam([{
                'params': self.resnet.baseParameters(),
                'lr': self.args.resnet_base_lr
            }, {
                'params': self.resnet.finetuneParameters(),
                'lr': self.args.resnet_ft_lr,
                'weight_decay': self.args.weight_decay
            }])
        else:
            raise ValueError(
                'unsupported resnet_optim: {}'.format(self.args.resnet_optim))
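
# Note: a minimal sketch of the sparsity penalty applied in epochTrainAutoEncoder
# above: the batch-mean embedding (rho_hat) is pulled toward a constant target rho
# via a KL divergence on softmax-normalised values. The function name and the
# default rho/weight values here are illustrative, not from the original source.
import torch
import torch.nn.functional as F

def sparse_penalty(embeds, rho=0.05, weight=1.0):
    # embeds: (batch, embed_dim) encoder outputs
    rho_tensor = torch.full((1, embeds.size(1)), rho, dtype=torch.float)
    rho_hat = torch.mean(embeds, dim=0, keepdim=True)
    return weight * F.kl_div(input=F.log_softmax(rho_hat, dim=-1),
                             target=F.softmax(rho_tensor, dim=-1))

# example: penalty for a random batch of 16 embeddings of size 64
print(sparse_penalty(torch.randn(16, 64)))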