Example No. 1
class Trainer:  # placeholder class name; the enclosing class was lost in extraction
    def __init__(self, train_loader, valid_loader, test_loader, model,
                 margin_penalty, train_loss_fn, test_loss_fn, sim_fn, device):

        self.train_loader = train_loader
        self.val_loader = valid_loader
        self.test_loader = test_loader

        self.model = model
        self.test_model = TripletNet(model)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.test_model = nn.DataParallel(self.test_model)

        self.margin_penalty = margin_penalty
        self.train_loss_fn = train_loss_fn  # loss function
        self.test_loss_fn = test_loss_fn  # loss function
        self.sim_fn = sim_fn  # similarity function

        self.device = device
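
A minimal wiring sketch for this constructor. All names below are hypothetical stand-ins (the snippet's class is assumed importable as Trainer, and TripletNet must be on the path); real code would pass the project's own loaders, embedding network, margin head, losses, and similarity function.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-ins so the sketch executes end to end.
dummy = TensorDataset(torch.randn(8, 16), torch.zeros(8, dtype=torch.long))
train_loader = valid_loader = test_loader = DataLoader(dummy, batch_size=4)

trainer = Trainer(
    train_loader, valid_loader, test_loader,
    model=nn.Linear(16, 8),               # placeholder embedding network
    margin_penalty=nn.Identity(),         # placeholder margin-penalty head
    train_loss_fn=nn.CrossEntropyLoss(),  # applied to the margin-penalized logits
    test_loss_fn=nn.TripletMarginLoss(),  # evaluation loss for the TripletNet wrapper
    sim_fn=lambda a, p, n: ((a - p).norm(dim=1), (a - n).norm(dim=1)),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)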
Example No. 2
CLASSF_MODEL_NAME = "2dCNNL2_TripletLoss_classf.pt"
""" Load Training data"""
# Load training data to list
training, training_labels = load_glued_image(GLUED_IMAGE_PATH, NUM_CLASS, 0,
                                             IMAGE_SIZE)
# Load the negative set of images
negative, negative_labels = load_data(NEGATIVE_DIR,
                                      n_class=NUM_CLASS,
                                      cls=5,
                                      img_size=IMAGE_SIZE)
print("-" * 30)
print("Data loaded!!")
print("-" * 30)

# 3 models extracting the pattern feature
triplet_model = TripletNet(CNNEmbeddingNetL2()).cuda()

# pattern_model = TripletNet(
#     torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
# ).cuda()

criterion = TripletLoss(margin=MARGIN)

# optimizer = optim.Adam(triplet_model.parameters(), lr=LR)
optimizer = optim.SGD(triplet_model.embedding_net.parameters(),
                      lr=LR,
                      momentum=0.9)
"""Training Phase"""
""" Stage 1: Train the Embedding Network"""
triplet_model = train_triplet(triplet_model,
                              criterion,
                              optimizer,
                              training_loader,  # assumed: DataLoader over the triplet data; the original snippet is truncated here
                              n_epoch=N_EPOCH)

Example No. 3

for model_type in ["pattern", "color", "shape"]:

    # Load paths and create Pytorch dataset
    training_paths, training_labels = load_data(TRAINING_DIR, NUM_CLASS)
    training = TripletVaseDataset(VaseDataset(training_paths, training_labels, IMAGE_SIZE, CROP_SIZE, model_type))

    # Make data loaders for the clustered CNN
    training_loader = DataLoader(training, batch_size=TRAINING_BATCH_SIZE, shuffle=True)

    print("-" * 30)
    print(f"{model_type} data loaded!!")
    print("-" * 30)

    in_channel = 2 if model_type == "color" else 3

    model = TripletNet(CNNEmbeddingNetL2(in_channel, 128)).cuda()
    # pattern_model = TripletNet(
    #     torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
    # ).cuda()

    """Training Phase"""
    print("-" * 30)
    print(f"Training {model_type} Model")
    print("-" * 30)

    criterion = TripletLoss(margin=MARGIN)

    # TODO: also try the Adam optimizer
    optimizer = optim.SGD(model.embedding_net.parameters(), lr=LR, momentum=0.9)
    model = train_triplet(model, criterion, optimizer, training_loader, n_epoch=N_EPOCH)
    torch.save(model, os.path.join(MODEL_ROOT_DIR, f"{TRIPLET_MODEL_NAME}_{model_type}_v{TRIPLET_MODEL_VERSION}.pt"))
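
Because the loop saves the entire TripletNet object rather than a state_dict, reloading for inference is a direct torch.load per model type. A sketch reusing the constants defined above:

import os
import torch

loaded_models = {}
for model_type in ["pattern", "color", "shape"]:
    path = os.path.join(MODEL_ROOT_DIR,
                        f"{TRIPLET_MODEL_NAME}_{model_type}_v{TRIPLET_MODEL_VERSION}.pt")
    loaded_models[model_type] = torch.load(path)  # full-module checkpoint
    loaded_models[model_type].eval()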
Example No. 4
def main(args):
    # model saving happens inside the every-10-epochs validation block below
    assert args.save_interval % 10 == 0, "save_interval must be a multiple of 10"

    # prepare dirs
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.save_model, exist_ok=True)
    
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    print("Device is", device)

    # img path loading
    with open("data/3d_data.pkl", mode='rb') as f:
        data_3d = pickle.load(f)
    train_path_list = data_3d.train_pl
    val_path_list = data_3d.val_pl

    train_dataset = TripletDataset(transform=ImageTransform(), flist=train_path_list)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    val_dataset = TripletDataset(transform=ImageTransform(), flist=val_path_list)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = TripletNet()
    model.to(device)
    
    criterion = nn.MarginRankingLoss(margin=args.margin)

    # choose params to train
    update_params_name = []
    for name, _ in model.named_parameters():
        if 'layer4' in name:
            update_params_name.append(name)
        elif 'fc' in name:
            update_params_name.append(name)

    print("**-----** update params **-----**")
    print(update_params_name)
    print("**-----------------------------**")
    print()

    params_to_update = choose_update_params(update_params_name, model)

    # set optimizer
    optimizer = optim.SGD(params_to_update, lr=1e-4, momentum=0.9)

    # run epoch    
    log_writer = SummaryWriter(log_dir=args.log_dir)
    for epoch in range(args.num_epochs):
        print("-"*80)
        print('Epoch {}/{}'.format(epoch+1, args.num_epochs))

        epoch_loss, epoch_acc = [], []
        for inputs, labels in tqdm(train_dataloader):
            batch_loss, batch_acc = train_one_batch(inputs, labels, model, criterion, optimizer, device)
            epoch_loss.append(batch_loss.item())
            epoch_acc.append(batch_acc.item())
        
        epoch_loss = np.array(epoch_loss)
        epoch_acc = np.array(epoch_acc)
        print('[Loss: {:.4f}], [Acc: {:.4f}] \n'.format(np.mean(epoch_loss), np.mean(epoch_acc)))
        log_writer.add_scalar("train/loss", np.mean(epoch_loss), epoch+1)
        log_writer.add_scalar("train/acc", np.mean(epoch_acc), epoch+1)


        # validation
        if (epoch+1) % 10 == 0:
            print("Run Validation")
            epoch_loss, epoch_acc = [], []
            for inputs, labels in tqdm(val_dataloader):
                batch_loss, batch_acc = validation(inputs, labels, model, criterion, device)
                epoch_loss.append(batch_loss.item())
                epoch_acc.append(batch_acc.item())
            
            epoch_loss = np.array(epoch_loss)
            epoch_acc = np.array(epoch_acc)
            print('[Validation Loss: {:.4f}], [Validation Acc: {:.4f}]'.format(np.mean(epoch_loss), np.mean(epoch_acc)))
            log_writer.add_scalar("val/loss", np.mean(epoch_loss), epoch+1)
            log_writer.add_scalar("val/acc", np.mean(epoch_acc), epoch+1)

            # save model
            if (args.save_interval > 0) and ((epoch+1) % args.save_interval == 0):
                save_path = os.path.join(args.save_model, '{}_epoch_{:.1f}.pth'.format(epoch+1, np.mean(epoch_loss)))
                torch.save(model.state_dict(), save_path)

    log_writer.close()
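
This example expresses the triplet objective through nn.MarginRankingLoss: with target y = 1 it computes mean(max(0, -(neg_dist - pos_dist) + margin)), which is exactly the triplet hinge mean(max(0, pos_dist - neg_dist + margin)). A self-contained check of that equivalence (the actual pairing of distances happens inside train_one_batch, which is not shown here):

import torch
import torch.nn as nn

margin = 1.0
criterion = nn.MarginRankingLoss(margin=margin)

pos_dist = torch.tensor([0.3, 1.2])  # anchor-positive distances
neg_dist = torch.tensor([0.9, 1.0])  # anchor-negative distances
target = torch.ones_like(pos_dist)   # y = 1 means neg_dist should exceed pos_dist

loss = criterion(neg_dist, pos_dist, target)
# identical to the explicit triplet hinge
assert torch.allclose(loss, torch.clamp(pos_dist - neg_dist + margin, min=0).mean())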
Example No. 5
class UIRTrainer:
    def __init__(self, sup_train_loader, semisup_train_loader, sup_valid_loader, semisup_valid_loader, sup_test_loader, semisup_test_loader, \
                model, margin_penalty, sup_train_loss_fn, semisup_train_loss_fn, test_loss_fn, sim_fn, device):

        self.sup_train_loader = sup_train_loader
        self.sup_val_loader = sup_valid_loader
        self.sup_test_loader = sup_test_loader
        self.semisup_train_loader = semisup_train_loader
        self.semisup_val_loader = semisup_valid_loader
        self.semisup_test_loader = semisup_test_loader

        self.model = model
        self.test_model = TripletNet(model)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.test_model = nn.DataParallel(self.test_model)
        self.margin_penalty = margin_penalty
        self.sup_train_loss_fn = sup_train_loss_fn  # loss function
        self.semisup_train_loss_fn = semisup_train_loss_fn  # loss function
        self.test_loss_fn = test_loss_fn  # loss function
        self.sim_fn = sim_fn  # similarity function

        self.device = device

    def fit(self,
            lr,
            n_epochs,
            log_interval,
            save_epoch_interval,
            start_epoch=0,
            outdir="../result/checkpoint/",
            data_dirname=None):
        """
        Loaders, model, loss function and metrics should work together for a given task,
        i.e. The model should be able to process data output of loaders,
        loss function should process target output of loaders and outputs from the model

        Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
        Siamese network: Siamese loader, siamese model, contrastive loss
        Online triplet learning: batch loader, embedding model, online triplet loss
        """

        sup_optimizer = optim.Adam(
            [{'params': self.model.parameters()},
             {'params': self.margin_penalty.parameters()}],
            lr=lr)
        sup_scheduler = lr_scheduler.StepLR(sup_optimizer, 8, gamma=0.1, last_epoch=-1)
        semisup_optimizer = optim.Adam(
            [{'params': self.model.parameters()},
             {'params': self.margin_penalty.parameters()}],
            lr=lr)
        semisup_scheduler = lr_scheduler.StepLR(semisup_optimizer, 8, gamma=0.1, last_epoch=-1)

        n_epochs *= 2  # doubled: run supervised training first, then semi-supervised training
        if start_epoch != 0:
            embedding_model = torch.load(
                f"{outdir}{data_dirname}_embeddingNet_out{self.model.num_out}_epoch{start_epoch-1}.pth"
            )
            model = torch.load(
                f"{outdir}{data_dirname}_model_out{self.model.num_out}_epoch{start_epoch-1}.pth"
            )
            margin_penalty = torch.load(
                f"{outdir}{data_dirname}_marginPenalty_out{self.model.num_out}_epoch{start_epoch-1}.pth"
            )
            self.model.load_state_dict(model)
            self.margin_penalty.load_state_dict(margin_penalty)

        # fast-forward the schedulers when resuming from a checkpointed epoch
        for epoch in range(0, start_epoch):
            if epoch < n_epochs / 2:
                sup_scheduler.step()
            else:
                semisup_scheduler.step()

        for epoch in range(start_epoch, n_epochs):
            # Train stage
            if epoch < n_epochs / 2:
                sup_scheduler.step()
                train_loss = self.sup_train_epoch(sup_optimizer, log_interval)
            else:
                semisup_scheduler.step()
                train_loss = self.semisup_train_epoch(semisup_optimizer,
                                                      log_interval)

            message = 'Epoch: {}/{}\n\tTrain set: Average loss: {:.4f}'.format(
                epoch + 1, n_epochs, train_loss)

            # Validation stage
            sup_val_loss, semisup_val_loss, sup_val_acc, semisup_val_acc = \
                self.validation_epoch()
            sup_val_loss /= len(self.sup_val_loader)
            semisup_val_loss /= len(self.semisup_val_loader)

            message += '\n\tValidation set: Average loss: labeled{:.6f}, unlabeled{:.6f}'.format(
                sup_val_loss, semisup_val_loss)
            message += '\n\t                Accuracy rate: labeled{:.6f}%, unlabeled{:.6f}%'.format(
                sup_val_acc, semisup_val_acc)

            # Test stage
            sup_test_loss, semisup_test_loss, sup_test_acc, semisup_test_acc = \
                self.test_epoch()
            sup_test_loss /= len(self.sup_test_loader)
            semisup_test_loss /= len(self.semisup_test_loader)

            message += '\n\tTest set: Average loss: labeled{:.6f}, unlabeled{:.6f}'.format(
                sup_test_loss, semisup_test_loss)
            message += '\n\t          Accuracy rate: labeled{:.6f}%, unlabeled{:.6f}%'.format(
                sup_test_acc, semisup_test_acc)

            logging.info(message + "\n")

            if data_dirname is not None and (epoch + 1) % save_epoch_interval == 0:
                # save the embedding sub-network separately from the full model
                torch.save(
                    self.model.embedding_net.state_dict(),
                    f"{outdir}{data_dirname}_embeddingNet_out{self.model.num_out}_epoch{epoch}.pth"
                )
                torch.save(
                    self.model.state_dict(),
                    f"{outdir}{data_dirname}_model_out{self.model.num_out}_epoch{epoch}.pth"
                )
                torch.save(
                    self.margin_penalty.state_dict(),
                    f"{outdir}{data_dirname}_marginPenalty_out{self.model.num_out}_epoch{epoch}.pth"
                )
        train_loss = train_loss if float(train_loss) != 0.0 else 10000.0

        return train_loss

    def sup_train_epoch(self, optimizer, log_interval):

        self.model.train()
        losses = []
        total_loss = 0

        for batch_idx, (data, target) in enumerate(self.sup_train_loader):
            data = data.to(self.device)
            target = target.to(self.device).long()

            optimizer.zero_grad()
            outputs = self.model(data)
            outputs = self.margin_penalty(outputs, target)

            loss_outputs = self.sup_train_loss_fn(outputs, target)

            loss = loss_outputs[0] if type(loss_outputs) in (
                tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.sup_train_loader.dataset),
                    100. * batch_idx / len(self.sup_train_loader),
                    np.mean(losses))

                logging.info(message)
                losses = []

        total_loss /= (batch_idx + 1)

        return total_loss

    def semisup_train_epoch(self, optimizer, log_interval):

        self.model.train()
        labeled_losses = []
        unlabeled_losses = []
        total_loss = 0

        for batch_idx, ((labeled_data, labeled_target),
                        (unlabeled_data, unlabeled_target)) in enumerate(
                            zip(self.sup_train_loader, self.semisup_train_loader)):
            labeled_data = labeled_data.to(self.device)
            unlabeled_data = unlabeled_data.to(self.device)
            labeled_target = labeled_target.to(self.device).long()
            unlabeled_target = unlabeled_target.to(self.device).long()

            optimizer.zero_grad()

            labeled_outputs = self.model(labeled_data)
            unlabeled_outputs = self.model(unlabeled_data)
            labeled_outputs = self.margin_penalty(labeled_outputs, labeled_target)
            unlabeled_outputs = self.margin_penalty(unlabeled_outputs, unlabeled_target)

            labeled_loss_outputs = self.sup_train_loss_fn(labeled_outputs, labeled_target)
            unlabeled_loss_outputs = self.semisup_train_loss_fn(unlabeled_outputs)
            labeled_loss = (labeled_loss_outputs[0] if type(labeled_loss_outputs)
                            in (tuple, list) else labeled_loss_outputs)
            unlabeled_loss = (unlabeled_loss_outputs[0] if type(unlabeled_loss_outputs)
                              in (tuple, list) else unlabeled_loss_outputs)
            labeled_losses.append(labeled_loss.item())
            unlabeled_losses.append(unlabeled_loss.item())
            loss = labeled_loss + unlabeled_loss
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{}, {}/{} ({:.0f}%)]\tLoss: labeled{:.6f}, unlabeled{:.6f}'.format(
                    batch_idx * len(labeled_data),
                    len(self.sup_train_loader.dataset),
                    batch_idx * len(unlabeled_data),
                    len(self.semisup_train_loader.dataset),
                    100. * batch_idx / len(self.sup_train_loader),
                    np.mean(labeled_losses), np.mean(unlabeled_losses))

                logging.info(message)
                labeled_losses = []
                unlabeled_losses = []

        total_loss /= (batch_idx + 1)

        return total_loss

    def validation_epoch(self):
        with torch.no_grad():

            self.test_model.eval()

            accuracy_rates = list()
            val_losses = list()
            for val_loader in [self.sup_val_loader, self.semisup_val_loader]:
                val_loss = 0
                n_true = 0
                for batch_idx, (data, _) in enumerate(val_loader):

                    if not type(data) in (tuple, list):
                        data = (data, )

                    data = tuple(d.to(self.device) for d in data)

                    outputs = self.test_model(*data)

                    if type(outputs) not in (tuple, list):
                        outputs = (outputs, )
                    loss_inputs = outputs

                    loss_outputs = self.test_loss_fn(*loss_inputs)
                    loss = loss_outputs[0] if type(loss_outputs) in (
                        tuple, list) else loss_outputs
                    val_loss += loss.item()

                    pos_dist, neg_dist = self.sim_fn(*loss_inputs)

                    for i in range(len(pos_dist)):
                        n_true += 1 if pos_dist[i] < neg_dist[i] else 0

                val_losses.append(val_loss)
                accuracy_rates.append((n_true / len(val_loader.dataset)) * 100)
            sup_val_loss, semisup_val_loss = val_losses
            sup_accuracy_rate, semisup_accuracy_rate = accuracy_rates
        return sup_val_loss, semisup_val_loss, sup_accuracy_rate, semisup_accuracy_rate

    def test_epoch(self):

        with torch.no_grad():

            self.test_model.eval()

            accuracy_rates = list()
            test_losses = list()
            for test_loader in [self.sup_test_loader, self.semisup_test_loader]:
                test_loss = 0
                n_true = 0
                for batch_idx, (data, _) in enumerate(test_loader):

                    if not type(data) in (tuple, list):
                        data = (data, )

                    data = tuple(d.to(self.device) for d in data)

                    outputs = self.test_model(*data)

                    if type(outputs) not in (tuple, list):
                        outputs = (outputs, )
                    loss_inputs = outputs

                    loss_outputs = self.test_loss_fn(*loss_inputs)
                    loss = loss_outputs[0] if type(loss_outputs) in (
                        tuple, list) else loss_outputs
                    test_loss += loss.item()

                    pos_dist, neg_dist = self.sim_fn(*loss_inputs)

                    for i in range(len(pos_dist)):
                        n_true += 1 if pos_dist[i] < neg_dist[i] else 0

                test_losses.append(test_loss)
                accuracy_rates.append(
                    (n_true / len(test_loader.dataset)) * 100)
            sup_test_loss, semisup_test_loss = test_losses
            sup_accuracy_rate, semisup_accuracy_rate = accuracy_rates
        return sup_test_loss, semisup_test_loss, sup_accuracy_rate, semisup_accuracy_rate
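
Both evaluation loops count a triplet as correct when the positive distance beats the negative one, so sim_fn only needs to map the three triplet embeddings to (pos_dist, neg_dist). A hypothetical Euclidean implementation of that contract:

import torch

def euclidean_sim_fn(anchor, positive, negative):
    # Hypothetical sim_fn: per-sample Euclidean distances for the triplet.
    # validation_epoch/test_epoch then score pos_dist[i] < neg_dist[i] as correct.
    pos_dist = (anchor - positive).pow(2).sum(dim=1).sqrt()
    neg_dist = (anchor - negative).pow(2).sum(dim=1).sqrt()
    return pos_dist, neg_dist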
Example No. 6
class ArcfaceTrainer:
    def __init__(self, train_loader, valid_loader, test_loader, model,
                 margin_penalty, train_loss_fn, test_loss_fn, sim_fn, device):

        self.train_loader = train_loader
        self.val_loader = valid_loader
        self.test_loader = test_loader

        self.model = model
        self.test_model = TripletNet(model)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self.test_model = nn.DataParallel(self.test_model)

        self.margin_penalty = margin_penalty
        self.train_loss_fn = train_loss_fn  # loss function
        self.test_loss_fn = test_loss_fn  # loss function
        self.sim_fn = sim_fn  # similarity function

        self.device = device

    def fit(self,
            lr,
            n_epochs,
            log_interval,
            save_epoch_interval,
            start_epoch=0,
            outdir="../result/checkpoint/",
            data_dirname=None):
        """
        Loaders, model, loss function and metrics should work together for a given task,
        i.e. The model should be able to process data output of loaders,
        loss function should process target output of loaders and outputs from the model

        Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
        Siamese network: Siamese loader, siamese model, contrastive loss
        Online triplet learning: batch loader, embedding model, online triplet loss
        """

        optimizer = optim.Adam(
            [{'params': self.model.parameters()},
             {'params': self.margin_penalty.parameters()}],
            lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)

        # fast-forward the scheduler when resuming from start_epoch
        for epoch in range(0, start_epoch):
            scheduler.step()

        for epoch in range(start_epoch, n_epochs):
            scheduler.step()

            # Train stage
            train_loss = self.train_epoch(optimizer, log_interval)

            message = 'Epoch: {}/{}\n\tTrain set: Average loss: {:.4f}'.format(
                epoch + 1, n_epochs, train_loss)

            # Validation stage
            val_loss, val_acc_rate = self.validation_epoch()
            val_loss /= len(self.val_loader)

            message += '\n\tValidation set: Average loss: {:.4f}'.format(
                val_loss)
            message += '\n\t                Accuracy rate: {:.2f}%'.format(
                val_acc_rate)

            # Test stage
            test_loss, test_acc_rate = self.test_epoch()
            test_loss /= len(self.test_loader)

            message += '\n\tTest set: Average loss: {:.4f}'.format(test_loss)
            message += '\n\t          Accuracy rate: {:.2f}%'.format(
                test_acc_rate)

            logging.info(message + "\n")

            if data_dirname is not None and (epoch + 1) % save_epoch_interval == 0:
                if torch.cuda.device_count() > 1:
                    num_out = self.model.module.num_out
                    torch.save(
                        self.model.module.embedding_net.state_dict(),
                        f"{outdir}{data_dirname}_embeddingNet_out{num_out}_epoch{epoch}.pth"
                    )
                    torch.save(
                        self.model.module.state_dict(),
                        f"{outdir}{data_dirname}_model_out{num_out}_epoch{epoch}.pth"
                    )
                else:
                    num_out = self.model.num_out
                    torch.save(
                        self.model.embedding_net.state_dict(),
                        f"{outdir}{data_dirname}_embeddingNet_out{num_out}_epoch{epoch}.pth"
                    )
                    torch.save(
                        self.model.state_dict(),
                        f"{outdir}{data_dirname}_model_out{num_out}_epoch{epoch}.pth"
                    )

        train_loss = train_loss if float(train_loss) != 0.0 else 10000.0

        return train_loss

    def train_epoch(self, optimizer, log_interval):

        self.model.train()
        losses = []
        total_loss = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):

            data = data.to(self.device)
            target = target.to(self.device).long()

            optimizer.zero_grad()
            outputs = self.model(data)
            outputs = self.margin_penalty(outputs, target)
            loss_outputs = self.train_loss_fn(outputs, target)

            loss = loss_outputs[0] if type(loss_outputs) in (
                tuple, list) else loss_outputs
            losses.append(loss.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), np.mean(losses))

                logging.info(message)
                losses = []

        total_loss /= (batch_idx + 1)

        return total_loss

    def validation_epoch(self):
        with torch.no_grad():

            self.test_model.eval()
            val_loss = 0

            n_true = 0
            for batch_idx, (data, _) in enumerate(self.val_loader):

                if not type(data) in (tuple, list):
                    data = (data, )

                data = tuple(d.to(self.device) for d in data)

                outputs = self.test_model(*data)

                if type(outputs) not in (tuple, list):
                    outputs = (outputs, )
                loss_inputs = outputs

                loss_outputs = self.test_loss_fn(*loss_inputs)
                loss = loss_outputs[0] if type(loss_outputs) in (
                    tuple, list) else loss_outputs
                val_loss += loss.item()

                pos_dist, neg_dist = self.sim_fn(*loss_inputs)

                for i in range(len(pos_dist)):
                    n_true += 1 if pos_dist[i] < neg_dist[i] else 0

            accuracy_rate = (n_true / len(self.val_loader.dataset)) * 100

        return val_loss, accuracy_rate

    def test_epoch(self):
        with torch.no_grad():

            self.test_model.eval()

            test_loss = 0
            n_true = 0
            for batch_idx, (data, _) in enumerate(self.test_loader):
                if not type(data) in (tuple, list):
                    data = (data, )

                data = tuple(d.to(self.device) for d in data)

                outputs = self.test_model(*data)

                if type(outputs) not in (tuple, list):
                    outputs = (outputs, )
                loss_inputs = outputs

                loss_outputs = self.test_loss_fn(*loss_inputs)
                loss = loss_outputs[0] if type(loss_outputs) in (
                    tuple, list) else loss_outputs
                test_loss += loss.item()

                pos_dist, neg_dist = self.sim_fn(*loss_inputs)

                for i in range(len(pos_dist)):
                    n_true += 1 if pos_dist[i] < neg_dist[i] else 0

            accuracy_rate = (n_true / len(self.test_loader.dataset)) * 100

        return test_loss, accuracy_rate
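
The checkpoints written by fit are plain state_dicts, so loading them back means rebuilding the modules first. A sketch, where EmbeddingNet and the path variables are hypothetical stand-ins for the project's own definitions:

import torch

embedding_net = EmbeddingNet()  # hypothetical: the architecture that produced the checkpoint
state = torch.load(f"{outdir}{data_dirname}_embeddingNet_out{num_out}_epoch{epoch}.pth")
embedding_net.load_state_dict(state)
embedding_net.eval()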