def main_with_angularloss(model, epoch, data_name):
    """
    Train ``model`` with the angular loss on the selected dataset.

    :param model: model whose ``parameters()`` are handed to the SGD optimizer
    :param epoch: number of epochs, forwarded to ``train_model`` as ``num_epochs``
    :param data_name: dataset to train on, ``'TissuePhysiology'`` or ``'LightClothing'``
    :return: None
    :raises ValueError: if ``data_name`` is not one of the supported dataset names
    """
    # Validate before building anything: previously an unknown name only printed
    # a message and then crashed later on the never-assigned loader variables.
    if data_name not in ('TissuePhysiology', 'LightClothing'):
        raise ValueError('Invalid data name. It can only be TissuePhysiology or LightClothing...')

    criterion_aloss = AngularLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)

    if data_name == 'TissuePhysiology':
        print('start loading TissuePhysiology dataset...')
        trainloader, valloader, testloader = data_loader.load_tissuephysiology_data()
    else:
        # Only 'LightClothing' can reach this branch after the check above.
        print('start loading LightClothing dataset...')
        trainloader, valloader, testloader = data_loader.load_lightclothing_data()

    dataloaders = {
        'train': trainloader,
        'val': valloader,
        'test': testloader,
    }

    train_model(model=model, dataloaders=dataloaders, criterion=criterion_aloss,
                optimizer=optimizer, scheduler=exp_lr_scheduler, num_epochs=epoch,
                inference=False)
def run_light_clothing_recognition(model, epoch):
    """
    Train ``model`` on the light clothing dataset with cross-entropy loss.

    Optimizer and scheduler hyper-parameters are read from ``cfg``
    (``init_lr``, ``weight_decay``, ``lr_decay_step``).

    :param model: model whose ``parameters()`` are handed to the SGD optimizer
    :param epoch: number of epochs, forwarded to ``train_model`` as ``num_epochs``
    :return: None
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(
        model.parameters(),
        lr=cfg['init_lr'],
        momentum=0.9,
        weight_decay=cfg['weight_decay'],
    )
    step_scheduler = lr_scheduler.StepLR(sgd, step_size=cfg['lr_decay_step'], gamma=0.1)

    print('start loading LightClothingDataset...')
    train_split, val_split, test_split = data_loader.load_lightclothing_data()
    loaders_by_phase = dict(train=train_split, val=val_split, test=test_split)

    train_model(
        model=model,
        dataloaders=loaders_by_phase,
        criterion=loss_fn,
        optimizer=sgd,
        scheduler=step_scheduler,
        num_epochs=epoch,
        inference=False,
    )
def main_with_centerloss(model, epoch, data_name):
    """
    Train ``model`` with cross-entropy loss plus center loss on the selected dataset.

    The center loss has its own parameters, which are updated by a separate
    SGD optimizer from the one that updates the model.

    :param model: model whose ``parameters()`` are handed to the model SGD optimizer
    :param epoch: number of epochs, forwarded to ``train_model`` as ``num_epochs``
    :param data_name: dataset to train on, ``'TissuePhysiology'`` or ``'LightClothing'``
    :return: None
    :raises ValueError: if ``data_name`` is not one of the supported dataset names
    """
    # Validate before building anything: previously an unknown name only printed
    # a message and then crashed later on the never-assigned loader variables.
    if data_name not in ('TissuePhysiology', 'LightClothing'):
        raise ValueError('Invalid data name. It can only be TissuePhysiology or LightClothing...')

    criterion_xent = nn.CrossEntropyLoss()
    # NOTE(review): feat_dim=1024 presumably has to match the feature size the
    # model produces -- confirm against the model definition.
    criterion_cent = CenterLoss(num_classes=cfg['out_num'], feat_dim=1024)

    optimizer_model = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    optimizer_centloss = optim.SGD(criterion_cent.parameters(), lr=0.5)
    # The scheduler is attached to the model optimizer only.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_model, step_size=60, gamma=0.1)

    if data_name == 'TissuePhysiology':
        print('start loading TissuePhysiology dataset...')
        trainloader, valloader, testloader = data_loader.load_tissuephysiology_data()
    else:
        # Only 'LightClothing' can reach this branch after the check above.
        print('start loading LightClothing dataset...')
        trainloader, valloader, testloader = data_loader.load_lightclothing_data()

    dataloaders = {
        'train': trainloader,
        'val': valloader,
        'test': testloader,
    }

    # NOTE(review): these keyword names (criterion_xent / criterion_cent /
    # optimizer_model / optimizer_centloss) differ from the criterion / optimizer
    # keywords used by the other train_model callers in this file -- confirm the
    # train_model in scope accepts them.
    train_model(model=model, dataloaders=dataloaders, criterion_xent=criterion_xent,
                criterion_cent=criterion_cent, optimizer_model=optimizer_model,
                optimizer_centloss=optimizer_centloss, scheduler=exp_lr_scheduler,
                num_epochs=epoch, inference=False)