def __init__(self, is_ensamble=True):
    """Set up the compute device, the EfficientNet classifier(s) and MTCNN.

    Args:
        is_ensamble: When True (the previous hard-coded behaviour), load the
            efficientnet-b0/b1/b2 checkpoints into ``self.models``. When
            False, load only the b0 checkpoint into ``self.model``. The
            attribute keeps its original spelling because other methods may
            branch on ``self.is_ensamble``.
    """
    super(ToyPredictor, self).__init__()
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.is_ensamble = is_ensamble
    if not self.is_ensamble:
        # Reuse the ensemble loader so the single-model path builds, loads
        # and places its model exactly like the ensemble path does.
        self.model = self.load_models(
            model_names=['efficientnet-b0'],
            model_paths=[
                './weight/output_my_aug/efn-b0_LS_27_loss_0.2205.pth'
            ])[0]
    else:
        self.models = self.load_models(
            model_names=[
                'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2'
            ],
            model_paths=[
                './weight/output_my_aug/efn-b0_LS_27_loss_0.2205.pth',
                './weight/output_my_aug/efn-b1_LS_6_loss_0.1756.pth',
                './weight/output_my_aug/efn-b2_LS_12_loss_0.1728.pth'
            ])
    # Face detector placed on the same device, switched to inference mode.
    self.mtcnn = MTCNN(margin=14,
                       keep_all=True,
                       factor=0.6,
                       device=self.device).eval()
    def load_models(self, model_names, model_paths):
        """Build two-class EfficientNet models and load their checkpoints.

        Args:
            model_names: Names passed to ``get_efficientnet`` (for example
                'efficientnet-b0').
            model_paths: State-dict checkpoint files, paired with
                ``model_names`` by position.

        Returns:
            list: The loaded models in input order. Each is moved to
            ``self.device`` and switched to eval mode.

        Raises:
            ValueError: If ``model_names`` and ``model_paths`` differ in
                length.
        """
        if len(model_names) != len(model_paths):
            # Previously a shorter path list raised a bare IndexError and a
            # longer one silently ignored the extra checkpoints.
            raise ValueError(
                'model_names and model_paths must have the same length, '
                'got {} and {}'.format(len(model_names), len(model_paths)))
        models = []
        pairs = zip(model_names, model_paths)
        for i, (model_name, model_path) in enumerate(pairs):
            model = get_efficientnet(model_name=model_name,
                                     num_classes=2,
                                     pretrained=False)
            # map_location remaps stored tensors onto self.device, so
            # checkpoints load regardless of the device they were saved on.
            model.load_state_dict(
                torch.load(model_path, map_location=self.device))
            print('Load model ', i, 'in:', model_path)
            model.to(self.device)
            # Inference mode (assumes get_efficientnet returns an nn.Module):
            # dropout is disabled and batch-norm uses running statistics,
            # consistent with the `.eval()` applied to the MTCNN detector.
            model.eval()
            models.append(model)

        return models
if __name__ == '__main__':
    # Training configuration.
    LOG_FREQ = 50  # presumably the logging interval in iterations; confirm in the loop
    batch_size = 128
    test_batch_size = 128
    device_id = 0  # CUDA device index used for both the model and the loss
    lr = 1e-3
    epoch_start = 1
    num_epochs = epoch_start + 50
    model_name = 'efficientnet-b1'
    # Log and checkpoint locations are keyed by the model name.
    writeFile = '/data1/cby/temp/output_my_aug/logs/' + model_name
    store_name = '/data1/cby/temp/output_my_aug/weights/' + model_name
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(store_name, exist_ok=True)
    # Set to a checkpoint path to resume training, e.g.
    # '/data1/cby/temp/output_my_aug/weights/efficientnet-b1/efn-b1_LS_9_loss_0.1610.pth'
    model_path = None
    model = get_efficientnet(model_name=model_name)
    if model_path is not None:
        # Load the weights on CPU; the whole model is moved to the GPU below.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        print('Model found in {}'.format(model_path))
    else:
        print('No model found, initializing random model.')
    model = model.cuda(device_id)
    train_logger = Logger(model_name=writeFile,
                          header=['epoch', 'loss', 'acc', 'lr'])

    # Alternatives tried: nn.CrossEntropyLoss(),
    # optim.SGD(model.parameters(), lr=lr, momentum=0.9),
    # optim.Adam(model.parameters(), lr=lr).
    criterion = LabelSmoothing(smoothing=0.05).cuda(device_id)
    optimizer = optim.AdamW(model.parameters(), lr=lr)