CLASSF_MODEL_NAME = "2dCNNL2_TripletLoss_classf.pt"

"""Load training data"""
# Load training data into lists
training, training_labels = load_glued_image(GLUED_IMAGE_PATH, NUM_CLASS, 0, IMAGE_SIZE)
# Load the negative set of images
negative, negative_labels = load_data(NEGATIVE_DIR, n_class=NUM_CLASS, cls=5, img_size=IMAGE_SIZE)

print("-" * 30)
print("Data loaded!!")
print("-" * 30)

# Embedding model extracting the pattern feature (one of the 3 feature models)
triplet_model = TripletNet(CNNEmbeddingNetL2()).cuda()
# pattern_model = TripletNet(
#     torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
# ).cuda()

criterion = TripletLoss(margin=MARGIN)
# optimizer = optim.Adam(triplet_model.parameters(), lr=LR)
optimizer = optim.SGD(triplet_model.embedding_net.parameters(), lr=LR, momentum=0.9)

"""Training Phase"""
"""Stage 1: Train the Embedding Network"""
# Note: the original call was truncated here; the remaining arguments are
# assumed to follow the train_triplet signature used in the per-feature script.
triplet_model = train_triplet(triplet_model, criterion, optimizer,
                              training_loader, n_epoch=N_EPOCH)
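# For reference, a minimal sketch of a margin-based triplet loss like the
# TripletLoss(margin=MARGIN) used above. This is an assumption about its
# behavior, not the repo's actual implementation: it penalizes triplets where
# the anchor-positive distance is not at least `margin` smaller than the
# anchor-negative distance.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLossSketch(nn.Module):
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Squared Euclidean distances between the triplet embeddings
        pos_dist = (anchor - positive).pow(2).sum(1)
        neg_dist = (anchor - negative).pow(2).sum(1)
        # Hinge: only triplets that violate the margin contribute to the loss
        return F.relu(pos_dist - neg_dist + self.margin).mean()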
for model_type in ["pattern", "color", "shape"]:
    # Load paths and create a PyTorch dataset
    training_paths, training_labels = load_data(TRAINING_DIR, NUM_CLASS)
    training = TripletVaseDataset(
        VaseDataset(training_paths, training_labels, IMAGE_SIZE, CROP_SIZE, model_type))

    # Make data loaders for the clustered CNN
    training_loader = DataLoader(training, batch_size=TRAINING_BATCH_SIZE, shuffle=True)

    print("-" * 30)
    print(f"{model_type} data loaded!!")
    print("-" * 30)

    in_channel = 2 if model_type == "color" else 3
    model = TripletNet(CNNEmbeddingNetL2(in_channel, 128)).cuda()
    # pattern_model = TripletNet(
    #     torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
    # ).cuda()

    """Training Phase"""
    print("-" * 30)
    print(f"Training {model_type} Model")
    print("-" * 30)

    criterion = TripletLoss(margin=MARGIN)
    # TODO: Could also try the Adam optimizer
    optimizer = optim.SGD(model.embedding_net.parameters(), lr=LR, momentum=0.9)

    model = train_triplet(model, criterion, optimizer, training_loader, n_epoch=N_EPOCH)
    torch.save(model, os.path.join(
        MODEL_ROOT_DIR, f"{TRIPLET_MODEL_NAME}_{model_type}_v{TRIPLET_MODEL_VERSION}.pt"))
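# TripletNet is used above (and in the trainers below) as a thin wrapper that
# pushes the three images of a triplet through one shared embedding network.
# A minimal sketch under that assumption; the repo's own TripletNet may differ
# in details such as an extra inference helper:
import torch.nn as nn

class TripletNetSketch(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        # One set of weights embeds all three inputs
        return (self.embedding_net(anchor),
                self.embedding_net(positive),
                self.embedding_net(negative))

    def get_embedding(self, x):
        # Convenience hook for embedding single batches at inference time
        return self.embedding_net(x)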
def main(args):
    assert args.save_interval % 10 == 0, "save_interval must be a multiple of 10"

    # Prepare output directories
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.save_model, exist_ok=True)

    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    print("Device is", device)

    # Load image path lists
    with open("data/3d_data.pkl", mode='rb') as f:
        data_3d = pickle.load(f)
    train_path_list = data_3d.train_pl
    val_path_list = data_3d.val_pl

    train_dataset = TripletDataset(transform=ImageTransform(), flist=train_path_list)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_dataset = TripletDataset(transform=ImageTransform(), flist=val_path_list)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = TripletNet()
    model.to(device)

    criterion = nn.MarginRankingLoss(margin=args.margin)

    # Choose which parameters to fine-tune (only layer4 and fc)
    update_params_name = []
    for name, _ in model.named_parameters():
        if 'layer4' in name or 'fc' in name:
            update_params_name.append(name)

    print("**-----** update params **-----**")
    print(update_params_name)
    print("**-----------------------------**")
    print()

    params_to_update = choose_update_params(update_params_name, model)

    # Set optimizer
    optimizer = optim.SGD(params_to_update, lr=1e-4, momentum=0.9)

    # Run epochs
    log_writer = SummaryWriter(log_dir=args.log_dir)
    for epoch in range(args.num_epochs):
        print("-" * 80)
        print('Epoch {}/{}'.format(epoch + 1, args.num_epochs))

        epoch_loss, epoch_acc = [], []
        for inputs, labels in tqdm(train_dataloader):
            batch_loss, batch_acc = train_one_batch(inputs, labels, model,
                                                    criterion, optimizer, device)
            epoch_loss.append(batch_loss.item())
            epoch_acc.append(batch_acc.item())
        epoch_loss = np.array(epoch_loss)
        epoch_acc = np.array(epoch_acc)
        print('[Loss: {:.4f}], [Acc: {:.4f}] \n'.format(np.mean(epoch_loss), np.mean(epoch_acc)))
        log_writer.add_scalar("train/loss", np.mean(epoch_loss), epoch + 1)
        log_writer.add_scalar("train/acc", np.mean(epoch_acc), epoch + 1)

        # Validation
        if (epoch + 1) % 10 == 0:
            print("Run Validation")
            epoch_loss, epoch_acc = [], []
            for inputs, labels in tqdm(val_dataloader):
                batch_loss, batch_acc = validation(inputs, labels, model, criterion, device)
                epoch_loss.append(batch_loss.item())
                epoch_acc.append(batch_acc.item())
            epoch_loss = np.array(epoch_loss)
            epoch_acc = np.array(epoch_acc)
            print('[Validation Loss: {:.4f}], [Validation Acc: {:.4f}]'.format(
                np.mean(epoch_loss), np.mean(epoch_acc)))
            log_writer.add_scalar("val/loss", np.mean(epoch_loss), epoch + 1)
            log_writer.add_scalar("val/acc", np.mean(epoch_acc), epoch + 1)

        # Save model checkpoint
        if (args.save_interval > 0) and ((epoch + 1) % args.save_interval == 0):
            save_path = os.path.join(args.save_model,
                                     '{}_epoch_{:.1f}.pth'.format(epoch + 1, np.mean(epoch_loss)))
            torch.save(model.state_dict(), save_path)

    log_writer.close()
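# train_one_batch and validation are defined elsewhere in this repo. As a
# hedged sketch of how nn.MarginRankingLoss(margin) is typically driven by a
# triplet model (the unpacking of `inputs`, the unused `labels`, and the
# accuracy definition here are assumptions): with target y = -1 the loss
# becomes max(0, d_pos - d_neg + margin), i.e. the positive pair must be
# closer than the negative pair by at least the margin.
import torch
import torch.nn.functional as F

def train_one_batch_sketch(inputs, labels, model, criterion, optimizer, device):
    anchor, positive, negative = (x.to(device) for x in inputs)  # labels unused here
    optimizer.zero_grad()
    emb_a, emb_p, emb_n = model(anchor, positive, negative)
    d_pos = F.pairwise_distance(emb_a, emb_p)
    d_neg = F.pairwise_distance(emb_a, emb_n)
    target = -torch.ones_like(d_pos)         # rank d_pos below d_neg
    loss = criterion(d_pos, d_neg, target)   # nn.MarginRankingLoss
    loss.backward()
    optimizer.step()
    # "Accuracy": fraction of triplets where the positive is the closer one
    acc = (d_pos < d_neg).float().mean()
    return loss.detach(), acc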
class UIRTrainer:
    def __init__(self, sup_train_loader, semisup_train_loader, sup_valid_loader,
                 semisup_valid_loader, sup_test_loader, semisup_test_loader,
                 model, margin_penalty, sup_train_loss_fn, semisup_train_loss_fn,
                 test_loss_fn, sim_fn, device):
        self.sup_train_loader = sup_train_loader
        self.sup_val_loader = sup_valid_loader
        self.sup_test_loader = sup_test_loader
        self.semisup_train_loader = semisup_train_loader
        self.semisup_val_loader = semisup_valid_loader
        self.semisup_test_loader = semisup_test_loader
        self.model = model
        self.test_model = TripletNet(model)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.test_model = nn.DataParallel(self.test_model)
        self.margin_penalty = margin_penalty
        self.sup_train_loss_fn = sup_train_loss_fn  # loss function
        self.semisup_train_loss_fn = semisup_train_loss_fn  # loss function
        self.test_loss_fn = test_loss_fn  # loss function
        self.sim_fn = sim_fn  # similarity function
        self.device = device

    def fit(self, lr, n_epochs, log_interval, save_epoch_interval,
            start_epoch=0, outdir="../result/checkpoint/", data_dirname=None):
        """
        Loaders, model, loss function and metrics should work together for a
        given task, i.e. the model should be able to process the data output of
        the loaders, and the loss function should process the target output of
        the loaders and the outputs of the model.

        Examples:
            Classification: batch loader, classification model, NLL loss, accuracy metric
            Siamese network: Siamese loader, Siamese model, contrastive loss
            Online triplet learning: batch loader, embedding model, online triplet loss
        """
        sup_optimizer = optim.Adam([
            {'params': self.model.parameters()},
            {'params': self.margin_penalty.parameters()},
        ], lr=lr)
        sup_scheduler = lr_scheduler.StepLR(sup_optimizer, 8, gamma=0.1, last_epoch=-1)
        semisup_optimizer = optim.Adam([
            {'params': self.model.parameters()},
            {'params': self.margin_penalty.parameters()},
        ], lr=lr)
        semisup_scheduler = lr_scheduler.StepLR(semisup_optimizer, 8, gamma=0.1, last_epoch=-1)

        n_epochs *= 2  # training runs twice: a supervised stage, then a semi-supervised stage

        if start_epoch != 0:
            embedding_model = torch.load(
                f"{outdir}{data_dirname}_embeddingNet_out{self.model.num_out}_epoch{start_epoch-1}.pth")
            model = torch.load(
                f"{outdir}{data_dirname}_model_out{self.model.num_out}_epoch{start_epoch-1}.pth")
            margin_penalty = torch.load(
                f"{outdir}{data_dirname}_marginPenalty_out{self.model.num_out}_epoch{start_epoch-1}.pth")
            self.model.load_state_dict(model)
            self.margin_penalty.load_state_dict(margin_penalty)
            # Fast-forward the schedulers to the resume point
            for epoch in range(0, start_epoch):
                if epoch < n_epochs / 2:
                    sup_scheduler.step()
                else:
                    semisup_scheduler.step()

        for epoch in range(start_epoch, n_epochs):
            # Train stage
            if epoch < n_epochs / 2:
                sup_scheduler.step()
                train_loss = self.sup_train_epoch(sup_optimizer, log_interval)
            else:
                semisup_scheduler.step()
                train_loss = self.semisup_train_epoch(semisup_optimizer, log_interval)
            message = 'Epoch: {}/{}\n\tTrain set: Average loss: {:.4f}'.format(
                epoch + 1, n_epochs, train_loss)

            # Validation stage
            sup_val_loss, semisup_val_loss, sup_val_acc, semisup_val_acc = self.validation_epoch()
            sup_val_loss /= len(self.sup_val_loader)
            semisup_val_loss /= len(self.semisup_val_loader)
            message += '\n\tValidation set: Average loss: labeled{:.6f}, unlabeled{:.6f}'.format(
                sup_val_loss, semisup_val_loss)
            message += '\n\t                Accuracy rate: labeled{:.6f}%, unlabeled{:.6f}%'.format(
                sup_val_acc, semisup_val_acc)

            # Test stage
            sup_test_loss, semisup_test_loss, sup_test_acc, semisup_test_acc = self.test_epoch()
            message += '\n\tTest set: Average loss: labeled{:.6f}, unlabeled{:.6f}'.format(
                sup_test_loss, semisup_test_loss)
            message += '\n\t          Accuracy rate: labeled{:.6f}%, unlabeled{:.6f}%'.format(
                sup_test_acc, semisup_test_acc)
            logging.info(message + "\n")

            if data_dirname is not None and (epoch + 1) % save_epoch_interval == 0:
                # Save the embedding network separately from the full model
                torch.save(
                    self.model.embedding_net.state_dict(),
                    f"{outdir}{data_dirname}_embeddingNet_out{self.model.num_out}_epoch{epoch}.pth")
                torch.save(
                    self.model.state_dict(),
                    f"{outdir}{data_dirname}_model_out{self.model.num_out}_epoch{epoch}.pth")
                torch.save(
                    self.margin_penalty.state_dict(),
                    f"{outdir}{data_dirname}_marginPenalty_out{self.model.num_out}_epoch{epoch}.pth")

        train_loss = train_loss if float(train_loss) != 0.0 else 10000.0
        return train_loss

    def sup_train_epoch(self, optimizer, log_interval):
        self.model.train()
        losses = []
        total_loss = 0
        for batch_idx, (data, target) in enumerate(self.sup_train_loader):
            data = data.to(self.device)
            target = target.to(self.device).long()

            optimizer.zero_grad()
            outputs = self.model(data)
            outputs = self.margin_penalty(outputs, target)
            loss_outputs = self.sup_train_loss_fn(outputs, target)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.sup_train_loader.dataset),
                    100. * batch_idx / len(self.sup_train_loader), np.mean(losses))
                logging.info(message)
                losses = []

        total_loss /= (batch_idx + 1)
        return total_loss

    def semisup_train_epoch(self, optimizer, log_interval):
        self.model.train()
        labeled_losses = []
        unlabeled_losses = []
        total_loss = 0
        for batch_idx, ((labeled_data, labeled_target),
                        (unlabeled_data, unlabeled_target)) in enumerate(
                            zip(self.sup_train_loader, self.semisup_train_loader)):
            labeled_data = labeled_data.to(self.device)
            unlabeled_data = unlabeled_data.to(self.device)
            labeled_target = labeled_target.to(self.device).long()
            unlabeled_target = unlabeled_target.to(self.device).long()

            optimizer.zero_grad()
            labeled_outputs = self.model(labeled_data)
            unlabeled_outputs = self.model(unlabeled_data)
            labeled_outputs = self.margin_penalty(labeled_outputs, labeled_target)
            unlabeled_outputs = self.margin_penalty(unlabeled_outputs, unlabeled_target)
            labeled_loss_outputs = self.sup_train_loss_fn(labeled_outputs, labeled_target)
            unlabeled_loss_outputs = self.semisup_train_loss_fn(unlabeled_outputs)
            labeled_loss = (labeled_loss_outputs[0]
                            if type(labeled_loss_outputs) in (tuple, list)
                            else labeled_loss_outputs)
            unlabeled_loss = (unlabeled_loss_outputs[0]
                              if type(unlabeled_loss_outputs) in (tuple, list)
                              else unlabeled_loss_outputs)
            labeled_losses.append(labeled_loss.item())
            unlabeled_losses.append(unlabeled_loss.item())

            loss = labeled_loss + unlabeled_loss
            total_loss += loss.item()  # accumulate a float, not the graph-attached tensor
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{}, {}/{} ({:.0f}%)]\tLoss: labeled{:.6f}, unlabeled{:.6f}'.format(
                    batch_idx * len(labeled_data), len(self.sup_train_loader.dataset),
                    batch_idx * len(unlabeled_data), len(self.semisup_train_loader.dataset),
                    100. * batch_idx / len(self.sup_train_loader),
                    np.mean(labeled_losses), np.mean(unlabeled_losses))
                logging.info(message)
                labeled_losses = []
                unlabeled_losses = []

        total_loss /= (batch_idx + 1)
        return total_loss

    def validation_epoch(self):
        with torch.no_grad():
            self.test_model.eval()
            accuracy_rates = []
            val_losses = []
            for val_loader in [self.sup_val_loader, self.semisup_val_loader]:
                val_loss = 0
                n_true = 0
                for batch_idx, (data, _) in enumerate(val_loader):
                    if not type(data) in (tuple, list):
                        data = (data, )
                    data = tuple(d.to(self.device) for d in data)

                    outputs = self.test_model(*data)
                    if type(outputs) not in (tuple, list):
                        outputs = (outputs, )
                    loss_inputs = outputs

                    loss_outputs = self.test_loss_fn(*loss_inputs)
                    loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
                    val_loss += loss.item()

                    # Count triplets where the positive pair is closer than the negative pair
                    pos_dist, neg_dist = self.sim_fn(*loss_inputs)
                    for i in range(len(pos_dist)):
                        n_true += 1 if pos_dist[i] < neg_dist[i] else 0
                val_losses.append(val_loss)
                accuracy_rates.append((n_true / len(val_loader.dataset)) * 100)

        sup_val_loss, semisup_val_loss = val_losses
        sup_accuracy_rate, semisup_accuracy_rate = accuracy_rates
        return sup_val_loss, semisup_val_loss, sup_accuracy_rate, semisup_accuracy_rate

    def test_epoch(self):
        with torch.no_grad():
            self.test_model.eval()
            accuracy_rates = []
            test_losses = []
            for test_loader in [self.sup_test_loader, self.semisup_test_loader]:
                test_loss = 0
                n_true = 0
                for batch_idx, (data, _) in enumerate(test_loader):
                    if not type(data) in (tuple, list):
                        data = (data, )
                    data = tuple(d.to(self.device) for d in data)

                    outputs = self.test_model(*data)
                    if type(outputs) not in (tuple, list):
                        outputs = (outputs, )
                    loss_inputs = outputs

                    loss_outputs = self.test_loss_fn(*loss_inputs)
                    loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
                    test_loss += loss.item()

                    pos_dist, neg_dist = self.sim_fn(*loss_inputs)
                    for i in range(len(pos_dist)):
                        n_true += 1 if pos_dist[i] < neg_dist[i] else 0
                test_losses.append(test_loss)
                accuracy_rates.append((n_true / len(test_loader.dataset)) * 100)

        sup_test_loss, semisup_test_loss = test_losses
        sup_accuracy_rate, semisup_accuracy_rate = accuracy_rates
        return sup_test_loss, semisup_test_loss, sup_accuracy_rate, semisup_accuracy_rate
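# sim_fn above is expected to map the unpacked test-model outputs (the triplet
# embeddings) to per-sample (positive, negative) distances, which
# validation_epoch/test_epoch then compare elementwise. A minimal sketch under
# that assumption:
import torch.nn.functional as F

def euclidean_sim_fn(anchor, positive, negative):
    # Returns two 1-D tensors of length batch_size
    pos_dist = F.pairwise_distance(anchor, positive)
    neg_dist = F.pairwise_distance(anchor, negative)
    return pos_dist, neg_dist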
class ArcfaceTrainer:
    def __init__(self, train_loader, valid_loader, test_loader, model,
                 margin_penalty, train_loss_fn, test_loss_fn, sim_fn, device):
        self.train_loader = train_loader
        self.val_loader = valid_loader
        self.test_loader = test_loader
        self.model = model
        self.test_model = TripletNet(model)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.test_model = nn.DataParallel(self.test_model)
        self.margin_penalty = margin_penalty
        self.train_loss_fn = train_loss_fn  # loss function
        self.test_loss_fn = test_loss_fn  # loss function
        self.sim_fn = sim_fn  # similarity function
        self.device = device

    def fit(self, lr, n_epochs, log_interval, save_epoch_interval,
            start_epoch=0, outdir="../result/checkpoint/", data_dirname=None):
        """
        Loaders, model, loss function and metrics should work together for a
        given task, i.e. the model should be able to process the data output of
        the loaders, and the loss function should process the target output of
        the loaders and the outputs of the model.

        Examples:
            Classification: batch loader, classification model, NLL loss, accuracy metric
            Siamese network: Siamese loader, Siamese model, contrastive loss
            Online triplet learning: batch loader, embedding model, online triplet loss
        """
        optimizer = optim.Adam([
            {'params': self.model.parameters()},
            {'params': self.margin_penalty.parameters()},
        ], lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)

        # Fast-forward the scheduler when resuming from a checkpoint
        for epoch in range(0, start_epoch):
            scheduler.step()

        for epoch in range(start_epoch, n_epochs):
            scheduler.step()

            # Train stage
            train_loss = self.train_epoch(optimizer, log_interval)
            message = 'Epoch: {}/{}\n\tTrain set: Average loss: {:.4f}'.format(
                epoch + 1, n_epochs, train_loss)

            # Validation stage
            val_loss, val_acc_rate = self.validation_epoch()
            val_loss /= len(self.val_loader)
            message += '\n\tValidation set: Average loss: {:.4f}'.format(val_loss)
            message += '\n\t                Accuracy rate: {:.2f}%'.format(val_acc_rate)

            # Test stage
            test_loss, test_acc_rate = self.test_epoch()
            message += '\n\tTest set: Average loss: {:.4f}'.format(test_loss)
            message += '\n\t          Accuracy rate: {:.2f}%'.format(test_acc_rate)
            logging.info(message + "\n")

            if data_dirname is not None and (epoch + 1) % save_epoch_interval == 0:
                if torch.cuda.device_count() > 1:
                    num_out = self.model.module.num_out
                    torch.save(
                        self.model.module.embedding_net.state_dict(),
                        f"{outdir}{data_dirname}_embeddingNet_out{num_out}_epoch{epoch}.pth")
                    torch.save(
                        self.model.module.state_dict(),
                        f"{outdir}{data_dirname}_model_out{num_out}_epoch{epoch}.pth")
                else:
                    num_out = self.model.num_out
                    torch.save(
                        self.model.embedding_net.state_dict(),
                        f"{outdir}{data_dirname}_embeddingNet_out{num_out}_epoch{epoch}.pth")
                    torch.save(
                        self.model.state_dict(),
                        f"{outdir}{data_dirname}_model_out{num_out}_epoch{epoch}.pth")

        train_loss = train_loss if float(train_loss) != 0.0 else 10000.0
        return train_loss

    def train_epoch(self, optimizer, log_interval):
        self.model.train()
        losses = []
        total_loss = 0
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data = data.to(self.device)
            target = target.to(self.device).long()

            optimizer.zero_grad()
            outputs = self.model(data)
            outputs = self.margin_penalty(outputs, target)
            loss_outputs = self.train_loss_fn(outputs, target)
            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), np.mean(losses))
                logging.info(message)
                losses = []

        total_loss /= (batch_idx + 1)
        return total_loss

    def validation_epoch(self):
        with torch.no_grad():
            self.test_model.eval()
            val_loss = 0
            n_true = 0
            for batch_idx, (data, _) in enumerate(self.val_loader):
                if not type(data) in (tuple, list):
                    data = (data, )
                data = tuple(d.to(self.device) for d in data)

                outputs = self.test_model(*data)
                if type(outputs) not in (tuple, list):
                    outputs = (outputs, )
                loss_inputs = outputs

                loss_outputs = self.test_loss_fn(*loss_inputs)
                loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
                val_loss += loss.item()

                # Count triplets where the positive pair is closer than the negative pair
                pos_dist, neg_dist = self.sim_fn(*loss_inputs)
                for i in range(len(pos_dist)):
                    n_true += 1 if pos_dist[i] < neg_dist[i] else 0

        accuracy_rate = (n_true / len(self.val_loader.dataset)) * 100
        return val_loss, accuracy_rate

    def test_epoch(self):
        with torch.no_grad():
            self.test_model.eval()
            test_loss = 0
            n_true = 0
            for batch_idx, (data, _) in enumerate(self.test_loader):
                if not type(data) in (tuple, list):
                    data = (data, )
                data = tuple(d.to(self.device) for d in data)

                outputs = self.test_model(*data)
                if type(outputs) not in (tuple, list):
                    outputs = (outputs, )
                loss_inputs = outputs

                loss_outputs = self.test_loss_fn(*loss_inputs)
                loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
                test_loss += loss.item()

                pos_dist, neg_dist = self.sim_fn(*loss_inputs)
                for i in range(len(pos_dist)):
                    n_true += 1 if pos_dist[i] < neg_dist[i] else 0

        accuracy_rate = (n_true / len(self.test_loader.dataset)) * 100
        return test_loss, accuracy_rate
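# The margin_penalty applied in train_epoch is an ArcFace-style additive
# angular margin head: it turns embeddings into scaled class logits with the
# target logit shifted from cos(theta) to cos(theta + m), after which an
# ordinary cross-entropy/NLL train_loss_fn applies. A minimal sketch; the
# repo's actual module may differ (e.g. easy-margin handling, default s and m):
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginSketch(nn.Module):
    def __init__(self, in_features, n_classes, s=30.0, m=0.50):
        super().__init__()
        self.s = s  # logit scale
        self.m = m  # additive angular margin
        self.weight = nn.Parameter(torch.empty(n_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, target):
        # cos(theta) between L2-normalized embeddings and class centers
        cosine = F.linear(F.normalize(embeddings), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target_logit = torch.cos(theta + self.m)
        one_hot = F.one_hot(target, num_classes=cosine.size(1)).float()
        # Only the ground-truth class receives the margin
        logits = one_hot * target_logit + (1.0 - one_hot) * cosine
        return logits * self.s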