Example #1
File: model.py Project: ck-amrahd/birds
    def train(self,
              train_image_indices,
              batch_size,
              num_epochs=50,
              train_method='normal',
              lambda_1=0,
              lambda_2=0,
              start_from_pretrained_model=True,
              learning_rate=0.01,
              optimizer='SGD'):
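        """Train the model and return per-epoch training/validation statistics.

        With train_method='bbox', the cross-entropy loss is augmented with an
        input-gradient penalty: gradient mass inside and outside the bounding
        box is weighted by lambda_1 and lambda_2, respectively. The model with
        the best validation accuracy is saved to self.checkpoint_path.
        """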

        if os.path.exists(self.checkpoint_path):
            os.remove(self.checkpoint_path)

        model = self.initialize_model(
            start_from_pretrained_model=start_from_pretrained_model)

        model = model.to(self.device)
        criterion = nn.CrossEntropyLoss()

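        # Build the optimizer from the string argument; any value other than
        # 'SGD' or 'Adam' falls back to SGD with momentum.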
        if optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate,
                                  momentum=0.9,
                                  weight_decay=5e-4)

        elif optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=learning_rate,
                                   weight_decay=5e-4)

        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate,
                                  momentum=0.9,
                                  weight_decay=5e-4)

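        # BatchLoader hands out mini-batch index lists drawn from
        # train_image_indices; it is reset at the start of every epoch.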
        train_batch_loader = BatchLoader(self.train_folder_path,
                                         train_image_indices)

        n_images = len(train_image_indices)
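        # Number of mini-batches per epoch (ceiling division).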
        if n_images % batch_size == 0:
            num_batches = n_images // batch_size
        else:
            num_batches = (n_images // batch_size) + 1

        penalty_inside_list = []
        penalty_outside_list = []
        train_acc_list = []
        train_loss_list = []
        val_loss_list = []
        val_acc_list = []
        best_acc = 0.0

        for epoch in range(num_epochs):
            model.train()
            train_batch_loader.reset()
            print('Epoch: {}/{}'.format(epoch + 1, num_epochs))
            print('-' * 50)

            train_correct = 0.0
            train_loss = 0.0
            penalty_inside = 0.0
            penalty_outside = 0.0

            for batch in range(num_batches):
                batch_indices = train_batch_loader.get_batch_indices(
                    batch_size)
                inputs = self.x_train[batch_indices]
                labels = self.y_train[batch_indices]
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                if train_method == 'bbox':
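                    # Penalize the gradient of the loss w.r.t. the input image:
                    # gradient mass inside and outside the bounding box is
                    # weighted by lambda_1 and lambda_2, respectively.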
                    inputs.requires_grad_()
                    outputs = model(inputs)
                    preds = torch.argmax(outputs, dim=1)

                    # cross entropy loss
                    loss = criterion(outputs, labels)
                    input_gradient = torch.autograd.grad(loss,
                                                         inputs,
                                                         create_graph=True)[0]
                    penalty_inside_box, penalty_outside_box = self.calculate_penalty_box(
                        batch_indices, input_gradient)
                    new_loss = loss + lambda_1 * penalty_inside_box + lambda_2 * penalty_outside_box
                    optimizer.zero_grad()
                    new_loss.backward()
                    optimizer.step()

                else:
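                    # Plain cross-entropy training; the box penalties stay at zero.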
                    outputs = model(inputs)
                    preds = torch.argmax(outputs, dim=1)

                    # cross entropy loss
                    loss = criterion(outputs, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    penalty_inside_box = torch.tensor(0).to(self.device)
                    penalty_outside_box = torch.tensor(0).to(self.device)

                train_loss += loss.item()
                train_correct += torch.sum(preds == labels).float().item()
                penalty_inside += penalty_inside_box.item() * lambda_1
                penalty_outside += penalty_outside_box.item() * lambda_2

            train_loss = train_loss / self.train_dataset_length
            train_loss_list.append(train_loss)
            train_acc = (train_correct / self.train_dataset_length) * 100.0
            train_acc_list.append(train_acc)
            penalty_inside = penalty_inside / self.train_dataset_length
            penalty_outside = penalty_outside / self.train_dataset_length
            penalty_inside_list.append(penalty_inside)
            penalty_outside_list.append(penalty_outside)

            print('Train Loss: {:.4f} Acc: {:.4f} % '.format(
                train_loss, train_acc))
            print(f'Penalty Inside Box: {round(penalty_inside, 4)}')
            print(f'Penalty Outside Box: {round(penalty_outside, 4)}')

            # validate after each epoch
            val_correct = 0.0
            val_loss = 0.0
            model.eval()
            with torch.no_grad():
                for inputs_val, labels_val in self.val_loader:
                    inputs_val, labels_val = inputs_val.to(
                        self.device), labels_val.to(self.device)
                    outputs_val = model(inputs_val)
                    preds_val = torch.argmax(outputs_val, dim=1)
                    loss_val = criterion(outputs_val, labels_val)

                    val_loss += loss_val.item()
                    val_correct += torch.sum(
                        preds_val == labels_val).float().item()

            val_loss = val_loss / self.val_dataset_length
            val_loss_list.append(val_loss)
            val_acc = (val_correct / self.val_dataset_length) * 100.0
            val_acc_list.append(val_acc)
            print('Val Loss: {:.4f} Acc: {:.4f} % \n'.format(
                val_loss, val_acc))

            # save the best model
            if val_acc > best_acc:
                best_acc = val_acc
                if os.path.exists(self.checkpoint_path):
                    os.remove(self.checkpoint_path)

                torch.save(model.state_dict(), self.checkpoint_path)

        return_dict = {
            'train_acc_list': train_acc_list,
            'train_loss_list': train_loss_list,
            'penalty_inside_list': penalty_inside_list,
            'penalty_outside_list': penalty_outside_list,
            'val_loss_list': val_loss_list,
            'val_acc_list': val_acc_list,
            'best_acc': best_acc
        }

        return return_dict
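
A minimal usage sketch. The Model class name, its constructor arguments, the index
range, and the hyperparameter values below are assumptions for illustration; only
the train(...) signature and the 'best_acc' key come from the example above.

    # Hypothetical driver code -- the constructor arguments are assumed, not
    # taken from the snippet above.
    model_wrapper = Model(checkpoint_path='checkpoints/best.pth',
                          train_folder_path='data/train')

    history = model_wrapper.train(train_image_indices=list(range(5000)),
                                  batch_size=64,
                                  num_epochs=50,
                                  train_method='bbox',
                                  lambda_1=0.1,
                                  lambda_2=0.1,
                                  start_from_pretrained_model=True,
                                  learning_rate=0.01,
                                  optimizer='SGD')

    print('Best validation accuracy: {:.2f} %'.format(history['best_acc']))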