예제 #1
0
def validation(args, model, testloader, epoch, mode='test'):
    """Evaluate ``model`` on ``testloader`` and collect per-site accuracy.

    Each batch is split by its ``site`` entries into UCSD samples and samples
    from any other site ("new"/SARS). A group is forwarded through the model
    only when it holds more than one sample; a batch in which neither group
    qualifies is skipped entirely.

    Args:
        args: namespace providing ``classes`` (number of classes, used to size
            the confusion matrix) and ``cuda`` (move batches to the GPU).
        model: called as ``model(x, 'ucsd')`` and expected to return a
            ``(class_scores, features)`` pair.
        testloader: iterable of ``(input_data, target, site)`` batches.
        epoch: epoch index, forwarded to ``print_summary``.
        mode: label forwarded to ``print_summary`` (defaults to ``'test'``).

    Returns:
        ``(metrics, conf_matrix, ucsd_correct_total, sars_correct_total,
        ucsd_test_total, sars_test_total)`` where ``conf_matrix[t, p]`` counts
        evaluated samples of true class ``t`` predicted as class ``p``.
    """
    device = torch.device('cuda' if args.cuda else 'cpu')
    conf_matrix = torch.zeros(args.classes, args.classes, device=device)
    model.eval()
    criterion = nn.CrossEntropyLoss()

    metrics = Metrics('')
    metrics.reset()
    ucsd_correct_total = 0
    sars_correct_total = 0
    ucsd_test_total = 0
    sars_test_total = 0
    with torch.no_grad():
        for input_data, target, site in tqdm(testloader):
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            # Split the batch by site with one boolean mask. This keeps the
            # within-group order of the batch and keeps labels on the same
            # device as the inputs (no per-sample numpy round trip).
            is_ucsd = torch.tensor([s == 'ucsd' for s in site],
                                   dtype=torch.bool, device=input_data.device)
            ucsd_input, ucsd_label = input_data[is_ucsd], target[is_ucsd].long()
            new_input, new_label = input_data[~is_ucsd], target[~is_ucsd].long()

            n_ucsd, n_new = len(ucsd_input), len(new_input)
            if n_ucsd < 2 and n_new < 2:
                # Nothing would be forwarded; previously the code fell through
                # and reused (or failed to find) a prior batch's output.
                continue

            outputs, label_parts = [], []
            if n_ucsd > 1:
                ucsd_output, _ = model(ucsd_input, 'ucsd')
                ucsd_correct, ucsd_total, _ = accuracy(ucsd_output, ucsd_label)
                ucsd_correct_total += ucsd_correct
                ucsd_test_total += ucsd_total
                outputs.append(ucsd_output)
                label_parts.append(ucsd_label)

            if n_new > 1:
                # NOTE(review): non-UCSD samples are also routed with 'ucsd' —
                # kept as in the original; confirm this is intended.
                new_output, _ = model(new_input, 'ucsd')
                sars_correct, sars_total, _ = accuracy(new_output, new_label)
                sars_correct_total += sars_correct
                sars_test_total += sars_total
                outputs.append(new_output)
                label_parts.append(new_label)

            # UCSD rows first, then the rest: labels stay aligned with output.
            output = torch.cat(outputs)
            labels = torch.cat(label_parts)

            loss = criterion(output, labels)

            preds = torch.argmax(output, dim=1)
            # Pair predictions with the regrouped labels they were computed
            # from, not with the raw batch ``target`` (different order/length).
            for t, p in zip(labels.view(-1), preds.view(-1)):
                conf_matrix[t.long(), p.long()] += 1

            correct, total, acc = accuracy(output, labels)

            # top k acc
            # NOTE(review): the 'top3_correct' key is computed with k=2 — kept
            # as in the original; confirm the intended k.
            top1_correct = top_k_acc(output, labels, k=1)
            top3_correct = top_k_acc(output, labels, k=2)

            metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc,
                            'top1_correct': top1_correct, 'top3_correct': top3_correct})

    print_summary(args, epoch, metrics, mode=mode)

    return metrics, conf_matrix, ucsd_correct_total, sars_correct_total, ucsd_test_total, sars_test_total
예제 #2
0
def train(args, model, trainloader, optimizer, epoch):
    """Train ``model`` for one epoch over ``trainloader``.

    Each batch is split by its ``site`` entries into UCSD samples and samples
    from any other site. A group is forwarded only when it holds more than one
    sample (presumably to avoid single-sample batches in train mode, e.g. for
    BatchNorm — confirm). Batches where neither group qualifies are skipped.

    The loss is cross-entropy. When exactly 32 feature rows are available and
    ``args.cont`` is set, an NT-Xent contrastive term (temperature 0.05) on the
    features is added. Gradients are clipped to a total norm of 2.0 before
    each optimizer step.

    Args:
        args: namespace providing ``cuda`` (move batches to the GPU) and
            ``cont`` (enable the contrastive term).
        model: called as ``model(x, 'ucsd')`` and expected to return a
            ``(class_scores, features)`` pair.
        trainloader: iterable of ``(input_data, target, site)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, forwarded to ``print_summary``.

    Returns:
        The ``Metrics`` object accumulated over the epoch.
    """
    model.train()
    # 'mean' is the current name of the deprecated 'elementwise_mean' option.
    criterion = nn.CrossEntropyLoss(reduction='mean')
    # The contrastive loss object is created once, on first use.
    cont_loss_func = None

    metrics = Metrics('')
    metrics.reset()
    for input_data, target, site in tqdm(trainloader):
        optimizer.zero_grad()
        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        # Split the batch by site with one boolean mask; keeps within-group
        # order and keeps labels on the inputs' device.
        is_ucsd = torch.tensor([s == 'ucsd' for s in site],
                               dtype=torch.bool, device=input_data.device)
        ucsd_input, ucsd_label = input_data[is_ucsd], target[is_ucsd].long()
        new_input, new_label = input_data[~is_ucsd], target[~is_ucsd].long()

        n_ucsd, n_new = len(ucsd_input), len(new_input)
        if n_ucsd < 2 and n_new < 2:
            # Nothing would be forwarded; previously this fell through to a
            # stale (already backpropagated) or undefined ``new_output``.
            continue

        outputs, feature_parts, label_parts = [], [], []
        if n_ucsd > 1:
            ucsd_output, ucsd_features = model(ucsd_input, 'ucsd')
            outputs.append(ucsd_output)
            feature_parts.append(ucsd_features)
            label_parts.append(ucsd_label)
        if n_new > 1:
            # NOTE(review): non-UCSD samples are also routed with 'ucsd' —
            # kept as in the original; confirm this is intended.
            new_output, new_features = model(new_input, 'ucsd')
            outputs.append(new_output)
            feature_parts.append(new_features)
            label_parts.append(new_label)

        # UCSD rows first, then the rest: outputs, features and labels align.
        output = torch.cat(outputs)
        features = torch.cat(feature_parts)
        labels = torch.cat(label_parts)

        # NOTE(review): 32 looks like the batch size, i.e. "no sample was
        # dropped" — confirm and consider deriving it from the loader.
        if len(features) == 32 and args.cont:
            if cont_loss_func is None:
                temperature = 0.05
                cont_loss_func = losses.NTXentLoss(temperature)
            cont_loss = cont_loss_func(features, labels)
            loss = criterion(output, labels) + cont_loss
        else:
            loss = criterion(output, labels)

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        correct, total, acc = accuracy(output, labels)
        # NOTE(review): the 'top3_correct' key is computed with k=2 — kept as
        # in the original; confirm the intended k.
        top1_correct = top_k_acc(output, labels, k=1)
        top3_correct = top_k_acc(output, labels, k=2)

        metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc,
                        'top1_correct': top1_correct, 'top3_correct': top3_correct})

    print_summary(args, epoch, metrics, mode="train")
    return metrics
예제 #3
0
# Build, train and evaluate a text CNN: embedding -> 1D convolution ->
# max-pool -> flatten -> 5-way softmax. Written against the old Keras API
# (nb_epoch, border_mode, predict_classes). Configuration (nb_words,
# outputDim, embedding_matrix, INPUT_LENGTH, FILTER_LENGTH, NB_FILTER,
# BATCH_SIZE, NB_EPOCH, CALLBACK_PATH, PRED_FILE, TRUE_FILE) and data
# (x_train, y_train, x_test, y_test, test_data) are defined elsewhere.
print("Building model...")
model = Sequential()
model.add(Embedding(input_dim=nb_words + 1, output_dim=outputDim, weights=[embedding_matrix],
                    input_length=INPUT_LENGTH, trainable=True))
# The old-API signature is Convolution1D(nb_filter, filter_length, ...): pass by
# keyword so the filter count and the kernel width cannot be swapped.
model.add(Convolution1D(nb_filter=NB_FILTER, filter_length=FILTER_LENGTH,
                        activation="relu", border_mode='valid'))
model.add(MaxPooling1D(21))
model.add(Flatten())
model.add(Dense(5, activation="softmax"))
model.compile(loss="categorical_crossentropy",
              optimizer="adam",
              metrics=['accuracy'])

print('Train...')

# Save only the weights with the best validation accuracy. Note: x_test also
# serves as the validation set here, and the evaluation below uses the
# final-epoch weights held in memory, not the checkpoint file.
checkpoint = ModelCheckpoint(CALLBACK_PATH, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
model.fit(x_train, y_train, batch_size=BATCH_SIZE, nb_epoch=NB_EPOCH, callbacks=callbacks_list,
          validation_data=(x_test, y_test))
# Named test_accuracy so the module-level `accuracy` helper is not rebound.
score, test_accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE)

y_pred = model.predict_classes(x_test, batch_size=BATCH_SIZE, verbose=1)
met.eval_MAE(test_data, y_pred, PRED_FILE=PRED_FILE, TRUE_FILE=TRUE_FILE)

print('Test score:', score)
print('Test accuracy:', test_accuracy)