Exemplo n.º 1
0
 def reinforce(self, rgb, nir, is_positive):
     """Fine-tune ``self.model`` on the pixels it currently misclassifies, then save it.

     A pixel becomes a training sample when the mask from
     ``ImageManager.get_GT(rgb)`` is truthy there and the model output from
     ``ImageManager.get_model_output_as_bw_image`` differs from
     ``is_positive``. The weights are then reloaded from ``self.load_path``,
     trained for one epoch on those samples with ``NDVILoss.apply`` as the
     criterion, and the resulting ``state_dict`` is written to
     ``self.save_path``.

     Args:
         rgb: Image indexed as ``rgb[row][col]`` over its first two
             dimensions (presumably height/width -- TODO confirm).
         nir: Second input, indexed the same way as ``rgb`` (presumably a
             near-infrared band -- confirm).
         is_positive: Value each selected pixel's prediction is compared
             against; also forwarded to ``ImageManager.get_input_sample``.
     """
     # NOTE(review): the return value is discarded, so this only has an effect
     # if reshape modifies rgb/nir in place -- confirm.
     ImageManager.reshape(rgb, nir)
     ground_truth = ImageManager.get_GT(rgb)
     # NOTE(review): these predictions come from the weights currently held by
     # self.model; the checkpoint at self.load_path is only loaded further
     # down -- confirm this ordering is intended.
     predicted = ImageManager.get_model_output_as_bw_image(self.model, rgb, nir)

     # Row-major scan: samples are collected in the same order a nested
     # row/column loop would produce.
     hard_samples = [
         ImageManager.get_input_sample(rgb[row][col], nir[row][col], is_positive)
         for row in range(rgb.shape[0])
         for col in range(rgb.shape[1])
         if ground_truth[row][col] and predicted[row][col] != is_positive
     ]

     loader = DataLoader(MyDataset(data=hard_samples),
                         batch_size=1,
                         shuffle=True,
                         num_workers=2)

     self.model.load_state_dict(torch.load(self.load_path))
     train_model(self.model, loader, NDVILoss.apply, max_epochs=1)
     torch.save(self.model.state_dict(), self.save_path)
Exemplo n.º 2
0
def main():
    """Train a ResNet-18 with a 3-way output head on ``args.data_csv``.

    The dataset is split 80/20 with ``train_test_split``. Both halves are
    pickled to ``train_data.pkl`` / ``test_data.pkl``. The test half is then
    dropped from memory and only reloaded from disk for each evaluation.

    Training runs for ``args.n_epoch`` epochs with Adam (default
    hyper-parameters) and cross-entropy loss. Test metrics are printed every
    ``args.test_interval`` epochs. A checkpoint is written via
    ``args.model_ckpt_path_temp`` every ``args.save_model_interval`` epochs
    and after the final epoch.
    """
    # Parse arguments.
    args = parse_args()

    # Set device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    root_dir = ""
    img_dataset = MyDataset(args.data_csv,
                            root_dir,
                            transform=transforms.ToTensor())

    # Split the dataset and persist both halves.
    train_data, test_data = train_test_split(img_dataset, test_size=0.2)
    # The training half has to stay in memory: train_loader keeps its own
    # reference to it. Deleting the local name therefore frees nothing, and
    # reloading it from disk would only keep a second, unused copy alive.
    # It is written once here so the split is available on disk.
    pd.to_pickle(train_data, "train_data.pkl")
    # The test half is only needed at evaluation time, so keep it on disk
    # between evaluations. These pickles are produced by this script; never
    # point read_pickle at files from an untrusted source.
    pd.to_pickle(test_data, "test_data.pkl")
    del test_data
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=20,
                                               shuffle=True)

    print('data set')
    # Set a model: ResNet-18 with its classifier replaced by a 3-class head.
    model = models.resnet18()
    model.train()
    model.fc = torch.nn.Linear(512, 3)
    model = model.to(device)

    print('model set')
    # Set loss function and optimization function.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    print('optimizer set')

    # Train and test.
    print('Train starts')
    for epoch in range(args.n_epoch):
        # Train a model for one epoch.
        train_acc, train_loss = train(model, device, train_loader, criterion,
                                      optimizer)

        # Output score.
        if epoch % args.test_interval == 0:
            # Load the held-out split only for the duration of this evaluation.
            test_data = pd.read_pickle("test_data.pkl")
            test_loader = torch.utils.data.DataLoader(test_data,
                                                      batch_size=20,
                                                      shuffle=True)
            del test_data
            test_acc, test_loss = test(model, device, test_loader, criterion)
            del test_loader

            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}'
            print(
                stdout_temp.format(epoch + 1, train_acc, train_loss, test_acc,
                                   test_loss))
        else:
            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}'
            print(stdout_temp.format(epoch + 1, train_acc, train_loss))

        # Save a model checkpoint periodically, and always after the last
        # epoch so the final weights are never lost.
        if (epoch % args.save_model_interval == 0
                or epoch + 1 == args.n_epoch):
            model_ckpt_path = args.model_ckpt_path_temp.format(
                args.dataset_name, args.model_name, epoch + 1)
            torch.save(model.state_dict(), model_ckpt_path)
            print('Saved a model checkpoint at {}'.format(model_ckpt_path))
            print('')
Exemplo n.º 3
0
def main():
    """Train the classifier selected by ``args.model`` and checkpoint it.

    Rows of ``args.data_csv`` are shuffled with NumPy seed 0. The first 20%
    (rounded down) of the shuffled rows form the test set and the rest the
    training set.

    Supported ``args.model`` values are ``'resnet18'`` (with a 3-way output
    head), ``'samplenet'`` and ``'simplenet'``; any other value raises
    ``NotImplementedError``.

    Training runs for ``args.n_epoch`` epochs with Adam at ``args.lr`` and
    cross-entropy loss. Test metrics are printed every ``args.test_interval``
    epochs. A checkpoint is written every ``args.save_model_interval`` epochs
    and after the final epoch.
    """
    args = parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Reproducible shuffle, then a positional 20/80 test/train split.
    np.random.seed(seed=0)
    frame = pd.read_csv(args.data_csv, engine='python', header=None)
    frame = frame.reindex(np.random.permutation(frame.index))
    n_test = int(len(frame) * 0.2)
    train_frame, test_frame = frame[n_test:], frame[:n_test]

    train_set = MyDataset(train_frame, transform=transforms.ToTensor())
    test_set = MyDataset(test_frame, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=20, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=20)
    print('data set')

    def build_resnet18():
        net = models.resnet18()
        net.fc = torch.nn.Linear(512, 3)  # replace the head with 3 outputs
        return net

    # Lambdas keep the class lookups lazy: only the chosen architecture's
    # name has to resolve.
    builders = {
        'resnet18': build_resnet18,
        'samplenet': lambda: SampleNet(),
        'simplenet': lambda: SimpleNet(),
    }
    builder = builders.get(args.model)
    if builder is None:
        raise NotImplementedError()
    model = builder()
    model.train()
    model = model.to(device)
    print('model set')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('optimizer set')

    print('Train starts')
    for epoch in range(args.n_epoch):
        shown_epoch = epoch + 1
        train_acc, train_loss = train(model, device, train_loader, criterion,
                                      optimizer)

        if epoch % args.test_interval != 0:
            summary = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}'.format(
                shown_epoch, train_acc, train_loss)
        else:
            test_acc, test_loss = test(model, device, test_loader, criterion)
            summary = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}'.format(
                shown_epoch, train_acc, train_loss, test_acc, test_loss)
        print(summary)

        should_save = (epoch % args.save_model_interval == 0
                       or shown_epoch == args.n_epoch)
        if should_save:
            ckpt_path = args.model_ckpt_path_temp.format(
                args.dataset_name, args.model_name, shown_epoch)
            torch.save(model.state_dict(), ckpt_path)
            print('Saved a model checkpoint at {}'.format(ckpt_path))
            print('')