Example 1
import os
import random

import torch


class CNN_DQN:
    def __init__(self, n_channel, n_action, gpu=False, lr=0.05) -> None:
        # Guard the CUDA queries so CPU-only machines do not crash on
        # get_device_name(0).
        if gpu and torch.cuda.is_available():
            print(torch.cuda.get_device_name(0))
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

        self.criterion = torch.nn.MSELoss()
        self.model = CNNModel(n_channel, n_action).to(self.device)
        self.model_target = CNNModel(n_channel, n_action).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

    def update(self, state, target):
        """Take one gradient step toward the given TD targets."""
        prediction = self.model(torch.Tensor(state).to(self.device))
        # Variable is deprecated since PyTorch 0.4; a plain tensor suffices.
        loss = self.criterion(prediction,
                              torch.Tensor(target).to(self.device))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def predict(self, state):
        with torch.no_grad():
            return self.model(torch.Tensor(state).to(self.device))

    def predict_target(self, state):
        with torch.no_grad():
            return self.model_target(torch.Tensor(state).to(self.device))

    def copy_target(self):
        self.model_target.load_state_dict(self.model.state_dict())

    def replay(self, memory, replay_size, gamma):
        """Sample a minibatch from experience memory and fit the TD targets."""
        if len(memory) >= replay_size:
            replay_data = random.sample(memory, replay_size)
            states = []
            td_targets = []

            for state, action, next_state, reward, is_done in replay_data:
                states.append(state.tolist()[0])
                q_values = self.predict(state).tolist()[0]
                if is_done:
                    q_values[action] = reward
                else:
                    # predict_target already runs under no_grad, so no
                    # extra detach() is needed.
                    q_values_next = self.predict_target(next_state)
                    q_values[action] = reward + gamma * torch.max(
                        q_values_next).item()
                td_targets.append(q_values)

            self.update(states, td_targets)

    def save(self):
        torch.save(self.model.state_dict(), 'cnn_dqn_model.pt')

    def load(self):
        if os.path.isfile('cnn_dqn_model.pt'):
            self.model.load_state_dict(torch.load('cnn_dqn_model.pt'))
            self.model.eval()
            print('Loaded model.')
        else:
            print('No model is found. New model initialized.')

    def gen_epsilon_greedy_policy(self, epsilon, n_action):
        def policy_function(state):
            if random.random() < epsilon:
                return random.randint(0, n_action - 1)
            else:
                q_values = self.predict(state)
                return torch.argmax(q_values).item()

        return policy_function
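
A minimal sketch of how this agent might be driven, assuming the classic Gym control loop (4-tuple step API) of the same era as this code; the environment name, replay settings, and hyperparameters below are illustrative assumptions, not taken from the source, and any observation preprocessing into a (1, C, H, W) array is omitted.

from collections import deque

import gym

# Hypothetical driver; env name and hyperparameters are assumptions.
env = gym.make('PongDeterministic-v4')
n_action = env.action_space.n
agent = CNN_DQN(n_channel=3, n_action=n_action, gpu=True, lr=0.0001)
agent.load()

memory = deque(maxlen=10000)
policy = agent.gen_epsilon_greedy_policy(epsilon=0.1, n_action=n_action)

for episode in range(100):
    state = env.reset()  # preprocessing to shape (1, C, H, W) omitted
    is_done = False
    while not is_done:
        action = policy(state)
        next_state, reward, is_done, _ = env.step(action)
        memory.append((state, action, next_state, reward, is_done))
        agent.replay(memory, replay_size=32, gamma=0.99)
        state = next_state
    if episode % 10 == 0:
        agent.copy_target()  # periodically sync the target network
agent.save()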
Example 2
import os

import torch

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        model.optimizer.zero_grad()
        outputs = model(inputs)
        loss = model.loss_fn(outputs, labels)
        loss.backward()
        model.optimizer.step()
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            # Report the average loss over the last 100 steps, then reset.
            print("Epoch: {}, Step: {}, Loss: {:.4f}".format(
                epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

# torch.save expects a file path: '~' is not expanded automatically and the
# original path pointed at a directory. The file name here is illustrative.
model_path = os.path.expanduser('~/Models/CIFAR10/CNN/cnn.pt')
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Test the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        imgs, labels = data
        outputs = model(imgs)
        # Index of the max logit is the predicted class; under no_grad
        # there is no need to go through .data.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Test Accuracy: ", correct / total)
Example 3
import uuid

import numpy as np
import torch
from setproctitle import setproctitle
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-level names (args, get_param_class, LeafDataset, CNNModel,
# Augmentation, getPredAndTarget, balanced_accuracy) are defined elsewhere
# in the source repository.


def train_attention():
    param_class = get_param_class(args.data)
    run_id = args.data + '_' + str(uuid.uuid1())

    dataset_train = LeafDataset(
        data_path=args.dataset_path,
        genotype=param_class.genotype,
        inoculated=param_class.inoculated,
        dai=param_class.dai,
        test_size=param_class.test_size,
        signature_pre_clip=param_class.signature_pre_clip,
        signature_post_clip=param_class.signature_post_clip,
        max_num_balanced_inoculated=param_class.max_num_balanced_inoculated,  # 50000
        num_samples_file=param_class.num_samples_file,
        split=args.split,
        mode='train',
        superpixel=True,
        bags=True,
        validation=False,
        transform=Augmentation())
    dataset_test = LeafDataset(
        data_path=args.dataset_path,
        genotype=param_class.genotype,
        inoculated=param_class.inoculated,
        dai=param_class.dai,
        test_size=param_class.test_size,
        signature_pre_clip=param_class.signature_pre_clip,
        signature_post_clip=param_class.signature_post_clip,
        max_num_balanced_inoculated=param_class.max_num_balanced_inoculated,  # 50000
        num_samples_file=param_class.num_samples_file,
        split=args.split,
        mode="test",
        superpixel=True,
        bags=True,
        validation=False)

    print("Number of samples train", len(dataset_train))
    print("Number of samples test", len(dataset_test))
    dataloader = DataLoader(dataset_train,
                            batch_size=1,
                            shuffle=True,
                            num_workers=0)
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 drop_last=False)

    hyperparams = dataset_train.hyperparams
    print("Number of batches train", len(dataloader))
    print("Number of batches test", len(dataloader_test))

    # Original class counts train: 67578 264112
    # Original class counts test: 68093 263597
    hyperparams['num_classes'] = param_class.num_classes
    hyperparams['hidden_layer_size'] = param_class.hidden_layer_size
    hyperparams['num_heads'] = param_class.num_heads
    hyperparams['lr'] = args.lr
    hyperparams['num_epochs'] = args.num_epochs
    hyperparams['lr_scheduler_steps'] = args.lr_scheduler_steps
    '''
    model = SANNetwork(input_size=dataset_train.input_size,
                       num_classes=hyperparams['num_classes'],
                       hidden_layer_size=hyperparams['hidden_layer_size'],
                       dropout=0.9,
                       num_heads=hyperparams['num_heads'],
                       device="cuda")
    '''
    #model = ConvNetBarley(elu=False, avgpool=False, nll=False, num_classes=param_class.num_classes)
    model = CNNModel(num_classes=param_class.num_classes)

    num_epochs = hyperparams['num_epochs']
    optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, hyperparams['lr_scheduler_steps'], gamma=0.5, last_epoch=-1)
    num_params = sum(p.numel() for p in model.parameters())
    print("Number of parameters {}".format(num_params))
    print("Starting training for {} epochs".format(num_epochs))
    save_dir = "./uv_dataset/results_cv/"
    writer = SummaryWriter(log_dir=save_dir + run_id,
                           comment="_" + "_id_{}".format(run_id))

    #device = "cuda"
    #model.to(device)

    #balanced_loss_weight = torch.tensor([0.75, 0.25], device=device)
    balanced_loss_weight = torch.tensor([0.75, 0.25])
    crit = torch.nn.CrossEntropyLoss(weight=balanced_loss_weight)
    best_acc = 0
    for epoch in tqdm(range(num_epochs)):
        setproctitle("Gerste_MIL" + args.mode +
                     " | epoch {} of {}".format(epoch + 1, num_epochs))
        losses_per_batch = []
        correct = 0
        target, pred = [], []
        total = 0
        for i, (features, labels) in enumerate(dataloader):
            #labels = labels[2]
            features = features.float()  #.to(device)
            # Swap the first two axes so the bag/instance dimension comes first.
            features = features.permute((1, 0, 2, 3, 4))
            labels = labels.long()  #.to(device)
            model.train()
            outputs, _ = model(features)
            outputs = outputs.view(labels.shape[0], -1)
            labels = labels.view(-1)
            loss = crit(outputs, labels)
            optimizer.zero_grad()
            batch_pred, batch_target = getPredAndTarget(outputs, labels)
            target.append(batch_target)
            pred.append(batch_pred)
            # Balanced accuracy is computed once per epoch below, so no
            # per-batch accuracy bookkeeping is needed here.
            total += labels.size(0)
            loss.backward()
            optimizer.step()
            losses_per_batch.append(float(loss))
        mean_loss = np.mean(losses_per_batch)
        correct = balanced_accuracy(target, pred)
        writer.add_scalar('Loss/train', mean_loss, epoch)
        writer.add_scalar('Accuracy/train', 100 * correct, epoch)
        print("Epoch {}, mean loss per batch {}, train acc {}".format(
            epoch, mean_loss, 100 * correct))

        if (epoch + 1) % args.test_epoch == 0 or epoch + 1 == num_epochs:
            correct_test = 0
            target, pred = [], []
            total = 0
            model.eval()
            losses_per_batch = []
            attention_weights = []
            with torch.no_grad():
                for i, (features, labels) in enumerate(dataloader_test):
                    #labels = labels[2]
                    features = features.float()  #.to(device)
                    features = features.permute((1, 0, 2, 3, 4))
                    labels = labels.long()  #.to(device)
                    outputs, att = model(features)
                    attention_weights.append(att.cpu().squeeze(0).numpy())
                    outputs = outputs.view(labels.shape[0], -1)
                    labels = labels.view(-1)
                    loss = crit(outputs, labels)
                    losses_per_batch.append(float(loss))
                    total += labels.size(0)
                    batch_pred, batch_target = getPredAndTarget(
                        outputs, labels)
                    target.append(batch_target)
                    pred.append(batch_pred)
                mean_loss = np.mean(losses_per_batch)
                # print(target, pred)  # debug output, left disabled
                correct_test = balanced_accuracy(target, pred)
                writer.add_scalar('Loss/test', mean_loss, epoch)
                #np.save('attention_weights.npy', attention_weights)
            print(
                'Accuracy, mean loss per batch of the network on the test samples: {} %, {}'
                .format(100 * correct_test, mean_loss))
            writer.add_scalar('Accuracy/test', 100 * correct_test, epoch)

            if correct_test >= best_acc:
                best_acc = correct_test
            model.train()

        scheduler.step()
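
getPredAndTarget and balanced_accuracy are not defined in this excerpt. Assuming targets and predictions are accumulated as lists of per-batch lists, as the loops above do, plausible implementations built on sklearn's balanced_accuracy_score could look like the following sketch; the bodies are guesses that merely match how the functions are called above.

import itertools

import torch
from sklearn.metrics import balanced_accuracy_score


def getPredAndTarget(outputs, labels):
    """Return per-batch predictions and targets as plain Python lists."""
    preds = torch.argmax(outputs, dim=1)
    return preds.cpu().tolist(), labels.cpu().tolist()


def balanced_accuracy(target, pred):
    """Flatten the per-batch lists and compute class-balanced accuracy."""
    y_true = list(itertools.chain.from_iterable(target))
    y_pred = list(itertools.chain.from_iterable(pred))
    return balanced_accuracy_score(y_true, y_pred)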