Code example #1
    def evaluate_similarity(self, question_file, threshold=0.8, prefix=False):
        question1 = []
        question2 = []
        gold_scores = []

        with open(question_file) as file:
            for line in file:
                line = line.rstrip()
                q1, q2, gold = line.split('\t')
                question1.append(q1)
                question2.append(q2)
                gold_scores.append(int(gold))

        question_vectors_1 = [
            self.question_to_vector(normalize_questions(x), prefix)
            for x in question1
        ]
        question_vectors_2 = [
            self.question_to_vector(normalize_questions(x), prefix)
            for x in question2
        ]

        scores = []
        for i in range(len(question_vectors_1)):
            if i % 10 == 0:
                # Report progress every 10 pairs, overwriting the same console line.
                progress = int(100 * ((i + 10) / len(question_vectors_1)))
                print('<{}>  Evaluating Question Pairs: {}%'.format(
                    datetime.datetime.now(), progress), end="\r")

            score = calculate_cosine_simil(question_vectors_1[i],
                                           question_vectors_2[i])
            if score > threshold:
                scores.append(1)
            else:
                scores.append(0)
        print()
        result = sklearn.metrics.log_loss(gold_scores, scores)
        TP, FP, TN, FN = perf_measure(gold_scores, scores)
        acc = np.sum(
            np.array(gold_scores) == np.array(scores)) / len(gold_scores)

        print('Log Loss: ' + str(result))
        print('Acc: ' + str(acc))
        print('TP: ' + str(TP) + '\tFP: ' + str(FP) + '\tTN: ' + str(TN) +
              '\tFN: ' + str(FN))
        print(scores)
        print(gold_scores)
        return result, acc, TP, FP, TN, FN
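The helpers `calculate_cosine_simil` and `perf_measure` are not shown in these examples. A minimal sketch of what they presumably do is given below; the exact signatures, and in particular the (TP, FP, TN, FN) return order, are only inferred from the call sites above and should be treated as assumptions.

import numpy as np

def calculate_cosine_simil(vec_a, vec_b):
    # Cosine similarity between two dense vectors (sketch of the unshown helper).
    a, b = np.asarray(vec_a, dtype=float), np.asarray(vec_b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom else 0.0

def perf_measure(gold_scores, scores):
    # Confusion-matrix counts for binary 0/1 labels, in the (TP, FP, TN, FN)
    # order the callers above unpack. Sketch only; the real helper is not shown.
    TP = FP = TN = FN = 0
    for gold, pred in zip(gold_scores, scores):
        if pred == 1 and gold == 1:
            TP += 1
        elif pred == 1 and gold == 0:
            FP += 1
        elif pred == 0 and gold == 0:
            TN += 1
        else:
            FN += 1
    return TP, FP, TN, FN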
Code example #2
    def eval_model(self):
        global TEXT_EVAL

        results = []
        for data in [self.traindata, self.validdata, self.testdata]:
            predictionclasses = []
            for dataslice, _ in self._sample_pairs(data, len(data['classes']),
                                                   shuffle=False, once=True):
                predictionclasses += list(self.model.predict(dataslice))
            #print(predictionclasses)
            scores = []
            for p in predictionclasses:
                if p[0]>p[1]:
                    scores.append(0)
                else:
                    scores.append(1)


            #prediction = np.dot(np.array(predictionclasses),np.arange(self.c['num_classes']))
            gold_scores = data['labels']


            result = sklearn.metrics.log_loss(gold_scores, scores)
            TP, FP, TN, FN = perf_measure(gold_scores, scores)
            acc = np.sum(np.array(gold_scores) == np.array(scores)) / len(gold_scores)

            print('Log Loss: ' + str(result))
            print('Acc: ' + str(acc))
            print('TP: ' + str(TP) + '\tFP: ' + str(FP) + '\tTN: ' + str(TN) + '\tFN: ' + str(FN))
            #print(scores)
            #print(gold_scores)
            results.append([result, acc, TP, FP, TN, FN])


            #result=pearsonr(prediction, goldlabels)[0]
            #results.append(round(result,4))
        print('TRAIN:' + str(results[0]))
        print('DEV:' + str(results[1]))
        print('TEST:' + str(results[2]))
        #print(results)

        return results
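Both examples above report accuracy, log loss, and the raw confusion counts. If precision, recall, and F1 are also wanted, they follow directly from the same four counts; a minimal sketch, with guards against empty denominators:

def derived_metrics(TP, FP, TN, FN):
    # Standard binary-classification metrics computed from confusion counts.
    total = TP + FP + TN + FN
    accuracy = (TP + TN) / total if total else 0.0
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accuracy, precision, recall, f1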
Code example #3
def evaluate(number, step):
    """
	Function to evaluate a model.
	It returns (on standard output) precision, recall, ... (see utils module).
	It also gives the ROC and AUC for the model, if sklearn is imported.

	Parameters
	----------
	numbers:	int
			Number of the model. All models are stored in src/logs/save/[number]/

	step:		int
			Step iteration. Example : step=4500, model4500 will be evaluated.

	Returns
	-------		
	"""

    # Build the model
    network = model.Model(
        n_inputs=FLAGS.note_size,
        n_outputs=FLAGS.output_size,
        n_hidden=FLAGS.n_hidden,
        n_step=FLAGS.future_size + FLAGS.past_size + 1,
        num_layers=FLAGS.num_layers,
        lr=FLAGS.learning_rate,
        momentum=FLAGS.momentum,
        target_note=FLAGS.past_size - 1,
    )
    # Create session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    # Restore the model
    saver = tf.train.Saver()
    print number
    print os.path.join(utils.getRoot(), FLAGS.save_dir, str(number),
                       FLAGS.save + str(step))
    saver.restore(
        session,
        os.path.join(utils.getRoot(), FLAGS.save_dir, str(number),
                     FLAGS.save + str(step)))
    # Build dataset
    parameters = {
        "directory": FLAGS.validation_dir,
        "batch_size": FLAGS.batch_size,
        "past": FLAGS.past_size,
        "future": FLAGS.future_size,
        "note_size": FLAGS.note_size,
        "output_size": FLAGS.output_size,
        "epsilon": FLAGS.epsilon_min,
        "meaning": True
    }
    validset = ValidationSet(**parameters)
    test_x, test_y = validset.next_batch(40 * FLAGS.batch_size)
    test_x_p, test_y_p = validset.next_batch(40 * FLAGS.batch_size,
                                             for_future=False)
    del validset
    # Evaluate model
    pred, real = network.test(session, test_x, test_y)
    #pred_tmp, real_tmp = network.test_past(session, test_x, test_y)
    #pred = np.append(pred, pred_tmp, axis=0)
    #real = np.append(real, real_tmp, axis=0)
    # Get metrics and print them
    dict_acc = utils.metrics(utils.perf_measure(pred, real))
    for x in dict_acc:
        print x
    # Get values for ROC
    test, score = network.roc(session, test_x, test_y)
    # Get ROC curves and AUCs
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(FLAGS.output_size):
        fpr[i], tpr[i], _ = roc_curve(score[:, i], test[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    # Plot all ROC curves
    plt.figure()
    for i in range(FLAGS.output_size):
        plt.plot(fpr[i],
                 tpr[i],
                 label='ROC curve of class {0} (area = {1:0.2f})'
                 ''.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(
        'Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.show()
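For reference, sklearn's `roc_curve` expects the true binary labels as its first argument and the scores as its second, so whether the `score`/`test` order in the call above is correct depends on what `network.roc` returns. Below is a self-contained version of the per-class ROC/AUC loop on synthetic data; the one-hot labels and per-class score shapes are assumptions.

import numpy as np
from sklearn.metrics import roc_curve, auc

rng = np.random.RandomState(0)
n_classes = 3
y_true = np.eye(n_classes)[rng.randint(0, n_classes, size=200)]  # one-hot ground truth
y_score = rng.rand(200, n_classes)                               # per-class scores

fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    # y_true comes first, y_score second.
    fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])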
Code example #4
def debug():
    """
	Function to debug model and dataset implemention.
	It run only one batch.
	No model or config file are created.

	Parameters
	----------

	Returns
	-------

	"""

    # Build the model
    network = model.Model(
        n_inputs=FLAGS.note_size,
        n_outputs=FLAGS.output_size,
        n_hidden=FLAGS.n_hidden,
        n_step=FLAGS.future_size + FLAGS.past_size + 1,
        num_layers=FLAGS.num_layers,
        lr=FLAGS.learning_rate,
        momentum=FLAGS.momentum,
        target_note=FLAGS.past_size - 1,
    )

    # Initialize session and Tensorboard
    session = tf.Session()
    session.run(tf.global_variables_initializer())

    parameters = {
        "directory": FLAGS.validation_dir,
        "batch_size": FLAGS.batch_size,
        "past": FLAGS.past_size,
        "future": FLAGS.future_size,
        "note_size": FLAGS.note_size,
        "output_size": FLAGS.output_size,
        "epsilon": FLAGS.epsilon_min,
        "meaning": True
    }
    # Pre-processing Dataset for Validation
    #validset = dataset.ValidSet(**parameters)
    validset = ValidationSet(**parameters)
    #for note in validset.datas[0]:
    #	print note[0:14], note[-12:]
    valid_x, valid_y = validset.next_batch(2 * FLAGS.batch_size)
    #valid_x_p, valid_y_p = validset.next_batch(FLAGS.batch_size, for_future=False)
    del validset
    print "[INFO] Validation Dataset Created"

    # Pre-processing Dataset for Testing
    parameters["directory"] = FLAGS.test_dir
    #testset = dataset.TestSet(**parameters)
    testset = TestingSet(**parameters)
    test_x, test_y = testset.next_transpositions_batch()
    del testset
    print "[INFO] Test Dataset Created"

    # Pre-processing Dataset for Training
    parameters["directory"] = FLAGS.training_dir
    parameters["epsilon"] = FLAGS.epsilon_max
    #sets = dataset.TrainSet(**parameters)
    trainset = TrainingSet(**parameters)
    print "[INFO] Training Dataset Created"

    # Debug session
    batch_x, batch_y = trainset.next_batch()
    network.debug(session, batch_x, batch_y)
    print "[INFO] Debugging batch done"
    network.debug(session, valid_x, valid_y)
    print "[INFO] Debugging valid done"
    network.debug(session, test_x, test_y)
    print "[INFO] Debugging test done"
    print "[INFO] Feeding Done"

    predicted, real = network.test(session, valid_x, valid_y)
    dict_acc = utils.metrics(utils.perf_measure(predicted, real))
    for x in dict_acc:
        print x

    # Write parameters and hyper-parameters into a file named output.txt
    with open(
            os.path.join(utils.getRoot(), FLAGS.config_dir,
                         "".join(["debug", ".txt"])), "wb") as debug_file:
        #debug_file.write("Time 		: {}\n".format(end - start))
        debug_file.write("Step 		: {}\n".format(FLAGS.max_step))
        debug_file.write("LR 		: {}\n".format(FLAGS.learning_rate))
        debug_file.write("Momentum 	: {}\n".format(FLAGS.momentum))
        debug_file.write("Past Size 	: {}\n".format(FLAGS.past_size))
        debug_file.write("Future Size 	: {}\n".format(FLAGS.future_size))
        debug_file.write("Batch Size 	: {}\n".format(FLAGS.batch_size))
        debug_file.write("Hidden Size 	: {}\n".format(FLAGS.n_hidden))
        debug_file.write("M-epsilon	: {}\n".format(FLAGS.epsilon_max))
        debug_file.write("m-Epsilon	: {}\n".format(FLAGS.epsilon_min))
        debug_file.write("Rhythm Size 	: {}\n\n".format(cv.TOTAL_SIZE))
        debug_file.write("Hasard 		: {}\n".format(0.5))
Code example #5
def run():
    """
	Function to run/build the neural network.
	Logs (for tensorboard) are stored in src/logs/
	A config file is created. It consists of how long it took, hyperparameters, and so on. (see end of the function).
	Parameters can be changed, see above.

	Parameters
	----------

	Returns
	-------
	"""

    # Build the model
    network = model.Model(
        n_inputs=FLAGS.note_size,
        n_outputs=FLAGS.output_size,
        n_hidden=FLAGS.n_hidden,
        n_step=FLAGS.future_size + FLAGS.past_size + 1,
        num_layers=FLAGS.num_layers,
        lr=FLAGS.learning_rate,
        momentum=FLAGS.momentum,
        target_note=FLAGS.past_size - 1,
    )

    # Initialize session and Tensorboard
    index = 1
    while os.path.isdir(
            os.path.join(utils.getPath(FLAGS.logs_dir, FLAGS.train_dir),
                         str(index))):
        index += 1

    session = tf.Session()
    session.run(tf.global_variables_initializer())

    # Training Summary
    tf.summary.scalar("loss", network.ferror)

    tf.summary.scalar("Accuracy", network.faccuracy)
    summary_op = tf.summary.merge_all()
    writer_dir = os.path.join(utils.getPath(FLAGS.logs_dir, FLAGS.train_dir),
                              str(index))
    os.makedirs(writer_dir)
    writer = tf.summary.FileWriter(writer_dir, graph=tf.get_default_graph())

    # Validation Summary
    summary_valid_op = tf.summary.merge_all()
    writer_dir = os.path.join(utils.getPath(FLAGS.logs_dir, FLAGS.valid_dir),
                              str(index))
    os.makedirs(writer_dir)
    writer_valid = tf.summary.FileWriter(writer_dir,
                                         graph=tf.get_default_graph())

    # Testing Summary
    summary_test_op = tf.summary.merge_all()
    writer_dir = os.path.join(utils.getPath(FLAGS.logs_dir, FLAGS.test_wdir),
                              str(index))
    os.makedirs(writer_dir)
    writer_test = tf.summary.FileWriter(writer_dir,
                                        graph=tf.get_default_graph())

    # Saving
    writer_dir = os.path.join(utils.getRoot(), FLAGS.save_dir, str(index))
    os.makedirs(writer_dir)

    parameters = {
        "directory": FLAGS.validation_dir,
        "batch_size": FLAGS.batch_size,
        "past": FLAGS.past_size,
        "future": FLAGS.future_size,
        "note_size": FLAGS.note_size,
        "output_size": FLAGS.output_size,
        "epsilon": FLAGS.epsilon_min,
        "meaning": True
    }
    # Pre-processing Dataset for Validation
    #validset = dataset.ValidSet(**parameters)
    validset = ValidationSet(**parameters)
    valid_x, valid_y = validset.next_batch(206)
    #valid_x_p, valid_y_p = validset.next_batch(40*FLAGS.batch_size, for_future=False)
    # Eval batch
    eval_x, eval_y = validset.eval_batch(206, for_future=True)
    tmp_x, tmp_y = validset.eval_batch(FLAGS.batch_size, for_future=False)

    del validset
    print "[INFO] Validation Dataset Created"

    # Pre-processing Dataset for Testing
    parameters["directory"] = FLAGS.test_dir
    #testset = dataset.TestSet(**parameters)
    testset = TestingSet(**parameters)
    test_x, test_y = testset.next_batch(206)
    #test_x_p, test_y_p = testset.next_batch(40*FLAGS.batch_size, for_future=False)
    del testset
    print "[INFO] Test Dataset Created"

    # Pre-processing Dataset for Training
    parameters["directory"] = FLAGS.training_dir
    parameters["epsilon"] = FLAGS.epsilon_max
    #sets = dataset.TrainSet(**parameters)
    trainset = TrainingSet(**parameters)
    print "[INFO] Training Dataset Created"
    batch_x, batch_y = trainset.next_batch()
    #---
    # It's high time we fed our model !
    #---
    start = time.time()
    for step in xrange(FLAGS.max_step):
        for i in xrange(FLAGS.max_iter):
            # Feed the model
            summary, _ = network.feed(session, summary_op, batch_x, batch_y)

            # Write log
            writer.add_summary(summary, step * FLAGS.max_iter + i)

        batch_x, batch_y = trainset.next_transpositions_batch()

        # Validation
        summary_valid, _ = network.eval(session, summary_valid_op, valid_x,
                                        valid_y)

        # Write Log
        writer_valid.add_summary(summary_valid, step * FLAGS.max_iter)

        # Test the model
        summary_test, _ = network.eval(session, summary_test_op, test_x,
                                       test_y)

        # Write Log
        writer_test.add_summary(summary_test, step * FLAGS.max_iter)

        # Display the current training step
        print("Iterations 	: {}".format(step))
        if (step + 1) % 10 == 0:
            trainset.epsilon = trainset.epsilon - 2 if trainset.epsilon > 1 else 1
            saver = tf.train.Saver()
            #saver.save(session, utils.getPath(writer_dir, FLAGS.save + str(step*FLAGS.max_iter)))
    """for step in xrange(FLAGS.max_step, 2*FLAGS.max_step):
		# Get next batch
		batch_x, batch_y = trainset.next_batch(for_future=False)

		# Feed the model
		summary, _ = network.feed_past(session, summary_op, batch_x, batch_y)

		# Write log
		writer.add_summary(summary, step)

		# Validation
		summary_valid, _ = network.eval_past(session, summary_valid_op, valid_x_p, valid_y_p)

		# Write Log
		writer_valid.add_summary(summary_valid, step)

		# Test the model
		summary_test, _ = network.eval_past(session, summary_test_op, test_x_p, test_y_p)

		# Write Log
		writer_test.add_summary(summary_test, step)

		# Display accuracy and loss every 200 steps time
		#if step % FLAGS.display_time == 0:
		#	print("Iterations 	: {}".format(step))"""
    # Save the model
    #writer_dir = os.path.join(utils.getRoot(), FLAGS.save_dir, str(index))
    #os.makedirs(writer_dir)
    saver = tf.train.Saver()
    saver.save(session, utils.getPath(writer_dir, FLAGS.save))

    end = time.time()
    predicted, real = network.test(session, eval_x, eval_y)
    dict_acc = utils.metrics(utils.perf_measure(predicted, real))
    print "On evaluation dataset"
    for x in dict_acc:
        print x

    predicted, real = network.test(session, valid_x, valid_y)
    dict_acc = utils.metrics(utils.perf_measure(predicted, real))
    print "On validation dataset"
    for x in dict_acc:
        print x

    # Write parameters and hyper-parameters into a file named output.txt
    with open(
            os.path.join(utils.getRoot(), FLAGS.config_dir,
                         "".join([str(index), ".txt"])), "wb") as config_file:
        config_file.write("Time 		: {}\n".format(end - start))
        config_file.write("Step 		: {}\n".format(FLAGS.max_step))
        config_file.write("LR 		: {}\n".format(FLAGS.learning_rate))
        config_file.write("Momentum 	: {}\n".format(FLAGS.momentum))
        config_file.write("Past Size 	: {}\n".format(FLAGS.past_size))
        config_file.write("Future Size 	: {}\n".format(FLAGS.future_size))
        config_file.write("Batch Size 	: {}\n".format(FLAGS.batch_size))
        config_file.write("Hidden Size 	: {}\n".format(FLAGS.n_hidden))
        config_file.write("M-epsilon	: {}\n".format(FLAGS.epsilon_max))
        config_file.write("m-Epsilon	: {}\n".format(FLAGS.epsilon_min))
        config_file.write("Rhythm Size 	: {}\n".format(cv.TOTAL_SIZE))
        config_file.write("Iter Size 	: {}\n".format(FLAGS.max_iter))
        config_file.write("Step Size 	: {}\n".format(FLAGS.max_step))
        config_file.write("Hasard 		: {}\n".format(0.5))
Code example #6
File: main.py  Project: xxlya/2CC3D_pytorch
def main():

    opt = parse_opts()
    print(opt)
    #pdb.set_trace()
    if not os.path.exists(opt.result_path):
        os.mkdir(opt.result_path)
    with open(os.path.join(opt.result_path, 'opts.json'), 'w') as opt_file:
        json.dump(vars(opt), opt_file)
    device = torch.device("cuda" if opt.use_cuda else "cpu")

    # Read Phenotype
    csv = pd.read_csv(opt.csv_dir)

    if opt.cross_val:

        for fold in range(5):  # change back to 5
            train_ID = dd.io.load(os.path.join(opt.MAT_dir,
                                               opt.splits))[fold]['X_train']
            val_ID = dd.io.load(os.path.join(opt.MAT_dir,
                                             opt.splits))[fold]['X_test']

            # ==========================================================================#
            #                       1. Network Initialization                          #
            # ==========================================================================#
            torch.manual_seed(opt.manual_seed)
            if opt.architecture == 'ResNet':
                kwargs = {
                    'inchn': opt.win_size,
                    'sample_size': opt.sample_size,
                    'sample_duration': opt.sample_duration,
                    'num_classes': opt.n_classes
                }
                model = resnet10(**kwargs).to(device)
            elif opt.architecture == 'NC3D':
                model = MyNet(opt.win_size, opt.nb_filter,
                              opt.batch_size).to(device)
            elif opt.architecture == 'CRNN':
                model = CNN_LSTM(opt.win_size, opt.nb_filter, opt.batch_size,
                                 opt.sample_size, opt.sample_duration,
                                 opt.rep).to(device)
            else:
                print('Architecture is not available.')
                raise LookupError
            print(model)
            model_parameters = filter(lambda p: p.requires_grad,
                                      model.parameters())
            num_params = sum([np.prod(p.size()) for p in model_parameters])
            print('number of trainable parameters:', num_params)
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            class_weights = torch.FloatTensor(opt.weights).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights)
            criterion.to(device)

            # ==========================================================================#
            #                     2. Setup Dataloading Parameters                      #
            # ==========================================================================#
            '''load subjects ID'''
            ID = csv['SUB_ID'].values
            win_size = opt.win_size  # number of input channels
            T = opt.sample_duration  # total length of the fMRI sequence
            num_rep = T // win_size  # number of times to repeat each ID

            # ==========================================================================#
            #                     3. Training and Validation                            #
            # ==========================================================================#

            if opt.architecture == 'ResNet':
                training_data = fMRIDataset(opt.datadir, win_size, train_ID, T,
                                            csv)
            elif opt.architecture == 'NC3D':
                training_data = fMRIDataset_2C(opt.datadir, train_ID)
            elif opt.architecture == 'CRNN':
                training_data = fMRIDataset_CRNN(opt.datadir, win_size,
                                                 train_ID, T, csv)
            else:
                print('Architecture is not available.')
                raise LookupError
            train_loader = torch.utils.data.DataLoader(
                training_data,
                batch_size=opt.batch_size,
                shuffle=True,
                num_workers=opt.n_threads,
                pin_memory=False)
            log_path = os.path.join(opt.result_path, str(fold))
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            train_logger = Logger(os.path.join(log_path, 'train.log'),
                                  ['epoch', 'loss', 'acc', 'lr'])
            train_batch_logger = Logger(
                os.path.join(log_path, 'train_batch.log'),
                ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
            '''optimization'''
            if opt.nesterov:
                dampening = 0
            else:
                dampening = opt.dampening
            if opt.optimizer == 'sgd':
                optimizer = optim.SGD(model.parameters(),
                                      lr=opt.learning_rate,
                                      momentum=opt.momentum,
                                      dampening=dampening,
                                      weight_decay=opt.weight_decay,
                                      nesterov=opt.nesterov)
            elif opt.optimizer == 'adam':
                optimizer = optim.Adam(model.parameters(),
                                       lr=opt.learning_rate,
                                       weight_decay=opt.weight_decay)
            elif opt.optimizer == 'adadelta':
                optimizer = optim.Adadelta(model.parameters(),
                                           lr=opt.learning_rate,
                                           weight_decay=opt.weight_decay)
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer, 'min', patience=opt.lr_patience)
            scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.2)
            if not opt.no_val:
                if opt.architecture == 'ResNet':
                    validation_data = fMRIDataset(opt.datadir, win_size,
                                                  val_ID, T, csv)
                elif opt.architecture == 'NC3D':
                    validation_data = fMRIDataset_2C(opt.datadir, val_ID)
                elif opt.architecture == 'CRNN':
                    validation_data = fMRIDataset_CRNN(opt.datadir, win_size,
                                                       val_ID, T, csv)
                val_loader = torch.utils.data.DataLoader(
                    validation_data,
                    batch_size=opt.n_val_samples,
                    shuffle=False,
                    num_workers=opt.n_threads,
                    pin_memory=False)
                val_logger = Logger(os.path.join(log_path, 'val.log'),
                                    ['epoch', 'loss', 'acc'])

            if opt.resume_path:
                print('loading checkpoint {}'.format(opt.resume_path))
                checkpoint = torch.load(opt.resume_path)
                # assert opt.arch == checkpoint['arch']

                opt.begin_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                if not opt.no_train:
                    optimizer.load_state_dict(checkpoint['optimizer'])

            print('run')
            best_loss = 1e4
            for i in range(opt.begin_epoch, opt.n_epochs + 1):
                if not opt.no_train:
                    train_epoch(i, train_loader, model, criterion, optimizer,
                                opt, log_path, train_logger,
                                train_batch_logger)
                if not opt.no_val:
                    validation_loss = val_epoch(i, val_loader, model,
                                                criterion, opt, val_logger)
                    if validation_loss < best_loss:
                        best_loss = validation_loss
                        best_model_wts = copy.deepcopy(model.state_dict())
                        torch.save(
                            best_model_wts,
                            os.path.join(log_path,
                                         str(fold) + '_best.pth'))
                if not opt.no_train and not opt.no_val:
                    #scheduler.step(validation_loss)
                    scheduler.step()

                model_wts = copy.deepcopy(model.state_dict())
                torch.save(
                    model_wts,
                    os.path.join(log_path,
                                 str(fold) + '_epoch_' + str(i) + '.pth'))

            # =========================================================================#
            #                            4. Testing                                    #
            # =========================================================================#

            if opt.test:
                model = MyNet(opt.win_size, opt.nb_filter,
                              opt.batch_size).to(device)
                model.load_state_dict(
                    torch.load(os.path.join(log_path,
                                            str(fold) + '_best.pth')))
                test_details_logger = Logger(
                    os.path.join(log_path, 'test_details.log'),
                    ['sub_id', 'pos', 'neg'])
                test_logger = Logger(os.path.join(log_path, 'test.log'), [
                    'fold', 'real_Y', 'pred_Y', 'acc', 'sen', 'spec', 'ppv',
                    'npv'
                ])
                real_Y = []
                pred_Y = []
                model.eval()
                if opt.no_val:
                    if opt.architecture == 'ResNet':
                        validation_data = fMRIDataset(opt.datadir, win_size,
                                                      val_ID, T, csv)
                    elif opt.architecture == 'NC3D':
                        validation_data = fMRIDataset_2C(opt.datadir, val_ID)
                    elif opt.architecture == 'CRNN':
                        validation_data = fMRIDataset_CRNN(
                            opt.datadir, win_size, val_ID, T, csv)
                test_loader = torch.utils.data.DataLoader(
                    validation_data,
                    batch_size=146 + 1 - opt.s_sz,
                    shuffle=False,
                    num_workers=opt.n_threads,
                    pin_memory=False)
                with torch.no_grad():
                    for i, (inputs, targets) in enumerate(test_loader):
                        real_Y.append(targets[0])
                        inputs, targets = inputs.to(device), targets.to(device)
                        inputs = Variable(inputs).float()
                        targets = Variable(targets).long()
                        outputs = model(inputs)
                        rest = np.argmax(outputs.detach().cpu().numpy(),
                                         axis=1)
                        pos = np.sum(rest == targets.detach().cpu().numpy())
                        neg = len(rest) - pos
                        print('pos:', pos, '  and neg:', neg)
                        test_details_logger.log({
                            'sub_id': val_ID[i * 142],
                            'pos': pos,
                            'neg': neg
                        })
                        if np.sum(rest == 1) >= np.sum(rest == 0):
                            pred_Y.append(1)
                        else:
                            pred_Y.append(0)
                TP, FP, TN, FN = perf_measure(real_Y, pred_Y)
                acc = (TP + TN) / (TP + TN + FP + FN)
                sen = TP / (TP + FN)
                spec = TN / (TN + FP)
                ppv = TP / (TP + FP)
                npv = TN / (TN + FN)
                test_logger.log({
                    'fold': fold,
                    'real_Y': real_Y,
                    'pred_Y': pred_Y,
                    'acc': acc,
                    'sen': sen,
                    'spec': spec,
                    'ppv': ppv,
                    'npv': npv
                })

    else:

        fold = opt.fold
        train_ID = dd.io.load(os.path.join(opt.MAT_dir,
                                           opt.splits))[fold]['X_train']
        val_ID = dd.io.load(os.path.join(opt.MAT_dir,
                                         opt.splits))[fold]['X_test']

        # ==========================================================================#
        #                       1. Network Initialization                          #
        # ==========================================================================#
        torch.manual_seed(opt.manual_seed)
        if opt.architecture == 'ResNet':
            kwargs = {
                'inchn': opt.win_size,
                'sample_size': opt.sample_size,
                'sample_duration': opt.sample_duration,
                'num_classes': opt.n_classes
            }
            model = resnet10(**kwargs).to(device)
        elif opt.architecture == 'NC3D':
            model = MyNet(opt.win_size, opt.nb_filter,
                          opt.batch_size).to(device)
        elif opt.architecture == 'CRNN':
            model = CNN_LSTM(opt.win_size, opt.nb_filter, opt.batch_size,
                             opt.s_sz, opt.sample_duration, opt.rep).to(device)
        else:
            print('Architecture is not available.')
            raise LookupError
        print(model)
        model_parameters = filter(lambda p: p.requires_grad,
                                  model.parameters())
        num_params = sum([np.prod(p.size()) for p in model_parameters])
        print('number of trainable parameters:', num_params)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        class_weights = torch.FloatTensor(opt.weights).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        criterion.to(device)

        # ==========================================================================#
        #                     2. Setup Dataloading Parameters                      #
        # ==========================================================================#
        '''load subjects ID'''
        win_size = opt.win_size  # number of input channels
        T = opt.sample_duration  # total length of the fMRI sequence

        # ==========================================================================#
        #                     3. Training and Validation                            #
        # ==========================================================================#
        # Repeat the IDs so that every volume in the fMRI is visited; this is the input to the dataloader.
        if opt.architecture == 'ResNet':
            training_data = fMRIDataset(opt.datadir, opt.s_sz, train_ID, T,
                                        csv, opt.rep)
        elif opt.architecture == 'NC3D':
            training_data = fMRIDataset_2C(opt.datadir, train_ID)
        elif opt.architecture == 'CRNN':
            training_data = fMRIDataset_CRNN(opt.datadir, opt.s_sz, train_ID,
                                             T, csv, opt.rep)
        train_loader = torch.utils.data.DataLoader(
            training_data,
            batch_size=opt.batch_size,
            shuffle=True,
            #num_workers=opt.n_threads,
            pin_memory=True)
        log_path = opt.result_path
        print('log_path', log_path)
        if not os.path.exists(log_path):
            os.mkdir(log_path)
        train_logger = Logger(os.path.join(log_path, 'train.log'),
                              ['epoch', 'loss', 'acc', 'lr'])
        train_batch_logger = Logger(
            os.path.join(log_path, 'train_batch.log'),
            ['epoch', 'batch', 'iter', 'loss', 'acc', 'lr'])
        '''optimization'''
        if opt.nesterov:
            dampening = 0
        else:
            dampening = opt.dampening
        if opt.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.learning_rate,
                                  momentum=opt.momentum,
                                  dampening=dampening,
                                  weight_decay=opt.weight_decay,
                                  nesterov=opt.nesterov)
        elif opt.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt.learning_rate,
                                   weight_decay=opt.weight_decay)
        elif opt.optimizer == 'adadelta':
            optimizer = optim.Adadelta(model.parameters(),
                                       lr=opt.learning_rate,
                                       weight_decay=opt.weight_decay)
        # scheduler = lr_scheduler.ReduceLROnPlateau(
        #     optimizer, 'min', patience=opt.lr_patience)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
        if not opt.no_val:
            if opt.architecture == 'ResNet':
                validation_data = fMRIDataset(opt.datadir, opt.s_sz, val_ID, T,
                                              csv, opt.rep)
            elif opt.architecture == 'NC3D':
                validation_data = fMRIDataset_2C(opt.datadir, val_ID)
            elif opt.architecture == 'CRNN':
                validation_data = fMRIDataset_CRNN(opt.datadir, opt.s_sz,
                                                   val_ID, T, csv, opt.rep)
            val_loader = torch.utils.data.DataLoader(
                validation_data,
                batch_size=opt.n_val_samples,
                shuffle=False,
                #num_workers=opt.n_threads,
                pin_memory=True)
            val_logger = Logger(os.path.join(log_path, 'val.log'),
                                ['epoch', 'loss', 'acc'])

        if opt.resume_path:
            print('loading checkpoint {}'.format(opt.resume_path))
            checkpoint = torch.load(opt.resume_path)
            # assert opt.arch == checkpoint['arch']

            opt.begin_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            if not opt.no_train:
                optimizer.load_state_dict(checkpoint['optimizer'])

        print('run')
        for i in range(opt.begin_epoch, opt.n_epochs + 1):
            if not opt.no_train:
                train_epoch(i, train_loader, model, criterion, optimizer, opt,
                            log_path, train_logger, train_batch_logger)
            if not opt.no_val:  # run validation at every epoch
                validation_loss = val_epoch(i, val_loader, model, criterion,
                                            opt, val_logger)

            if not opt.no_train and not opt.no_val:
                scheduler.step(validation_loss)

        # =========================================================================#
        #                            4. Testing                                    #
        # =========================================================================#

        if opt.test:
            test_details_logger = Logger(
                os.path.join(opt.result_path, 'test_details.log'),
                ['sub_id', 'pos', 'neg'])
            test_logger = Logger(os.path.join(opt.result_path, 'test.log'), [
                'fold', 'real_Y', 'pred_Y', 'acc', 'sen', 'spec', 'ppv', 'npv'
            ])
            real_Y = []
            pred_Y = []
            model.eval()
            test_loader = torch.utils.data.DataLoader(
                validation_data,
                batch_size=142,
                shuffle=False,
                num_workers=opt.n_threads,
                pin_memory=False)
            with torch.no_grad():
                for i, (inputs, targets) in enumerate(test_loader):
                    real_Y.append(targets)
                    inputs, targets = inputs.to(device), targets.to(device)
                    inputs = Variable(inputs).float()
                    targets = Variable(targets).long()
                    outputs = model(inputs)
                    rest = np.argmax(outputs.detach().cpu().numpy(), axis=1)
                    pred_Y.append(outputs.detach().cpu().numpy())
                    pos = np.sum(rest == targets.detach().cpu().numpy())
                    neg = len(rest) - pos
                    #print('pos:', pos, '  and neg:', neg)
                    test_details_logger.log({
                        'sub_id': val_ID[i * 142],
                        'pos': pos,
                        'neg': neg
                    })
            TP, FP, TN, FN = perf_measure(real_Y, pred_Y)
            acc = (TP + TN) / (TP + TN + FP + FN)
            sen = TP / (TP + FN)
            spec = TN / (TN + FP)
            ppv = TP / (TP + FP)
            npv = TN / (TN + FN)
            test_logger.log({
                'fold': fold,
                'real_Y': real_Y,
                'pred_Y': pred_Y,
                'acc': acc,
                'sen': sen,
                'spec': spec,
                'ppv': ppv,
                'npv': npv
            })
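In the cross-validation branch, each subject's per-window predictions are reduced to a single label by majority vote (the `np.sum(rest == 1) >= np.sum(rest == 0)` test above), whereas this branch appends the raw network outputs. A small sketch of that voting rule, factored out as a helper; the (n_windows, n_classes) input shape is an assumption based on how the outputs are used.

import numpy as np

def majority_vote(window_outputs):
    # Reduce per-window class scores of shape (n_windows, n_classes) to a single
    # subject-level 0/1 label by voting over the per-window argmax, mirroring the
    # rule used in the cross-validation branch above. Sketch only.
    votes = np.argmax(np.asarray(window_outputs), axis=1)
    return 1 if np.sum(votes == 1) >= np.sum(votes == 0) else 0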