def train_autoencoder_deep_fashion():
    """Set up and launch autoencoder training on DeepFashion.

    Prints a banner, places a ``CifarNet`` encoder and a ``Decoder`` on
    ``DEVICE``, and optimises the parameters of both networks jointly with
    Adam under an MSE loss, using a ``StepLR`` learning-rate schedule.

    Returns:
        Whatever ``train`` returns for this configuration.
    """
    for banner_line in (
        "=============================================================",
        "================ Train AE with DeepFashion ==================",
        "=============================================================\n",
    ):
        print(banner_line)

    encoder = CifarNet(input_channels=3, num_classes=50).to(DEVICE)
    decoder = Decoder(input_channels=64, num_classes=50, out_channels=3).to(DEVICE)

    # One optimizer updates the encoder and decoder weights together.
    joint_params = [*encoder.parameters(), *decoder.parameters()]
    adam = torch.optim.Adam(joint_params, lr=LEARNING_RATE_TRAIN)

    # Decay the learning rate by a factor of GAMMA every STEP_SIZE_TRAIN steps.
    lr_schedule = torch.optim.lr_scheduler.StepLR(
        adam,
        step_size=STEP_SIZE_TRAIN,
        gamma=GAMMA,
    )

    reconstruction_loss = nn.MSELoss()

    return train(
        encoder,
        decoder,
        reconstruction_loss,
        adam,
        lr_schedule,
        EPOCHS,
        train_loader_deep_fashion,
    )
def train_exemplar_cnn_deep_fashion():
    """Set up and launch ExemplarCNN training on DeepFashion.

    Prints a banner, builds a ``CifarNet`` whose number of output classes
    equals the size of the DeepFashion training dataset, and trains it with
    cross-entropy loss, Adam and a ``StepLR`` learning-rate schedule.

    Returns:
        Whatever ``train`` returns for this configuration.
    """
    for banner_line in (
        "============================================================",
        "============ Train ExemplarCNN with DeepFashion ============",
        "============================================================\n",
    ):
        print(banner_line)

    # Every training image is its own class, so the classifier head is as
    # wide as the training dataset is long.
    num_exemplars = len(train_loader_deep_fashion.dataset)
    exemplar_net = CifarNet(input_channels=3, num_classes=num_exemplars).to(DEVICE)

    criterion = nn.CrossEntropyLoss()

    # All model parameters are optimised.
    adam = torch.optim.Adam(exemplar_net.parameters(), lr=LEARNING_RATE_TRAIN)

    # Decay the learning rate by a factor of GAMMA every STEP_SIZE_TRAIN steps.
    lr_schedule = torch.optim.lr_scheduler.StepLR(
        adam,
        step_size=STEP_SIZE_TRAIN,
        gamma=GAMMA,
    )

    return train(
        exemplar_net,
        criterion,
        adam,
        lr_schedule,
        EPOCHS,
        train_loader_deep_fashion,
    )
def train_rotation_net_deep_fashion():
    """Set up and launch rotation-model training on DeepFashion.

    Prints a banner, builds a four-class ``CifarNet`` on ``DEVICE`` and
    trains it with cross-entropy loss, Adam and a ``StepLR`` learning-rate
    schedule, passing both the training and validation loaders to ``train``.

    Returns:
        Whatever ``train`` returns for this configuration.
    """
    for banner_line in (
        "============================================================",
        "========== Train Rotation Model with DeepFashion ===========",
        "============================================================\n",
    ):
        print(banner_line)

    rotation_net = CifarNet(input_channels=3, num_classes=4).to(DEVICE)

    criterion = nn.CrossEntropyLoss()

    # All model parameters are optimised.
    adam = torch.optim.Adam(rotation_net.parameters(), lr=LEARNING_RATE_TRAIN)

    # Decay the learning rate by a factor of GAMMA every STEP_SIZE_TRAIN steps.
    lr_schedule = torch.optim.lr_scheduler.StepLR(
        adam,
        step_size=STEP_SIZE_TRAIN,
        gamma=GAMMA,
    )

    return train(
        rotation_net,
        criterion,
        adam,
        lr_schedule,
        EPOCHS,
        train_loader_deep_fashion,
        val_loader_deep_fashion,
    )
def train_supervised_deep_fashion():
    """Set up and launch supervised training on DeepFashion.

    Prints a banner, builds a 50-class ``CifarNet`` on ``DEVICE`` and hands
    it to ``fine_tune`` together with a cross-entropy loss, an Adam optimizer,
    a ``StepLR`` learning-rate schedule and the training/validation loaders.

    Returns:
        Whatever ``fine_tune`` returns for this configuration.
    """
    for banner_line in (
        "============================================================",
        "============= Supervised Training DeepFashion ==============",
        "============================================================\n",
    ):
        print(banner_line)

    classifier = CifarNet(input_channels=3, num_classes=50).to(DEVICE)

    criterion = nn.CrossEntropyLoss()

    # All model parameters are optimised.
    adam = torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE_TRAIN)

    # Decay the learning rate by a factor of GAMMA every STEP_SIZE_TRAIN steps.
    lr_schedule = torch.optim.lr_scheduler.StepLR(
        adam,
        step_size=STEP_SIZE_TRAIN,
        gamma=GAMMA,
    )

    return fine_tune(
        classifier,
        criterion,
        adam,
        lr_schedule,
        EPOCHS,
        train_loader_deep_fashion,
        val_loader_deep_fashion,
    )