Example #1
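Loads pickled per-fold reports and plots precision-recall and ROC curves for each fold, plus the pooled predictions, through a Summary instance.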
def test_plot():
    with open(FLG.model + '_stat.pkl', 'rb') as f:
        reports = pickle.load(f)
    summary = Summary(port=39199, env=FLG.model)

    over_y_true = []
    over_y_score = []
    for running_fold in range(FLG.fold):
        report = reports[running_fold]
        over_y_true += report.y_true
        over_y_score += report.y_score
        summary.precision_recall_curve(report.y_true,
                                       report.y_score,
                                       true_label=0,
                                       name='fold' + str(running_fold))
        summary.roc_curve(report.y_true,
                          report.y_score,
                          true_label=0,
                          name='fold' + str(running_fold))

    summary.precision_recall_curve(over_y_true,
                                   over_y_score,
                                   true_label=0,
                                   name='over all')
    summary.roc_curve(over_y_true, over_y_score, true_label=0, name='over all')
Example #2
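Learner constructor for an IMPALA elevator-dispatch setup: derives observation/action dimensions from IntraBuildingEnv, builds the agent, starts the learner and remote-manager threads, and wires up CSV and Summary logging.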
    def __init__(self, config):
        self.config = config
        self.sample_data_queue = queue.Queue(
            maxsize=config['sample_queue_max_size'])

        #=========== Create Agent ==========
        env = IntraBuildingEnv("config.ini")
        self._mansion_attr = env._mansion.attribute
        self._obs_dim = obs_dim(self._mansion_attr)
        self._act_dim = act_dim(self._mansion_attr)

        self.config['obs_shape'] = self._obs_dim
        self.config['act_dim'] = self._act_dim

        model = RLDispatcherModel(self._act_dim)
        algorithm = IMPALA(model, hyperparas=config)
        self.agent = ElevatorAgent(algorithm, config, self.learn_data_provider)

        self.cache_params = self.agent.get_params()
        self.params_lock = threading.Lock()
        self.params_updated = False
        self.cache_params_sent_cnt = 0
        self.total_params_sync = 0

        #========== Learner ==========
        self.lr, self.entropy_coeff = None, None
        self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
        self.entropy_coeff_scheduler = PiecewiseScheduler(
            config['entropy_coeff_scheduler'])

        self.total_loss_stat = WindowStat(100)
        self.pi_loss_stat = WindowStat(100)
        self.vf_loss_stat = WindowStat(100)
        self.entropy_stat = WindowStat(100)
        self.kl_stat = WindowStat(100)
        self.learn_time_stat = TimeStat(100)
        self.start_time = None

        self.learn_thread = threading.Thread(target=self.run_learn)
        self.learn_thread.setDaemon(True)
        self.learn_thread.start()

        #========== Remote Actor ===========
        self.remote_count = 0

        self.batch_buffer = []
        self.remote_metrics_queue = queue.Queue()
        self.sample_total_steps = 0

        self.remote_manager_thread = threading.Thread(
            target=self.run_remote_manager)
        self.remote_manager_thread.setDaemon(True)
        self.remote_manager_thread.start()

        self.csv_logger = CSVLogger(
            os.path.join(logger.get_dir(), 'result.csv'))

        from utils import Summary
        self.summary = Summary('./output')
Example #3
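Complete k-fold training script for an ADNI classifier: builds datasets and loaders, instantiates the model selected by name, trains with periodic validation, streams scalars to Summary, and checkpoints the best epoch.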
def main():
    # option flags
    FLG = train_args()

    # torch setting
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])

    # create summary and report the option
    visenv = FLG.model
    summary = Summary(port=39199, env=visenv)
    summary.viz.text(argument_report(FLG, end='<br>'),
                     win='report' + str(FLG.running_fold))
    train_report = ScoreReport()
    valid_report = ScoreReport()
    timer = SimpleTimer()
    fold_str = 'fold' + str(FLG.running_fold)
    best_score = dict(epoch=0, loss=1e+100, accuracy=0)

    #### create dataset ###
    # kfold split
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'))
    trainblock, validblock, ratio = fold_split(
        FLG.fold, FLG.running_fold, FLG.labels,
        np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)

    def _dataset(block, transform):
        return ADNIDataset(FLG.labels,
                           pjoin(FLG.data_root, FLG.modal),
                           block,
                           target_dict,
                           transform=transform)

    # create train set
    trainset = _dataset(trainblock, transform_presets(FLG.augmentation))

    # create normal valid set
    validset = _dataset(
        validblock,
        transform_presets('nine crop' if FLG.augmentation ==
                          'random crop' else 'no augmentation'))

    # each loader
    trainloader = DataLoader(trainset,
                             batch_size=FLG.batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)
    validloader = DataLoader(validset, num_workers=4, pin_memory=True)

    # data check
    # for image, _ in trainloader:
    #     summary.image3d('asdf', image)

    # create model
    def kaiming_init(tensor):
        return kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels),
                      name=FLG.model,
                      weights_initializer=kaiming_init)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    else:
        raise NotImplementedError(FLG.model)

    print_model_parameters(model)
    model = torch.nn.DataParallel(model, FLG.devices)
    model.to(device)

    # criterion
    train_criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(
        list(map(lambda x: x * 2, reversed(ratio))))).to(device)
    valid_criterion = torch.nn.CrossEntropyLoss().to(device)

    # TODO resume
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLG.lr,
                                 weight_decay=FLG.l2_decay)
    # scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FLG.lr_gamma)

    start_epoch = 0
    global_step = start_epoch * len(trainloader)
    pbar = None
    for epoch in range(1, FLG.max_epoch + 1):
        timer.tic()
        scheduler.step()
        summary.scalar('lr',
                       fold_str,
                       epoch - 1,
                       optimizer.param_groups[0]['lr'],
                       ytickmin=0,
                       ytickmax=FLG.lr)

        # train()
        torch.set_grad_enabled(True)
        model.train(True)
        train_report.clear()
        if pbar is None:
            pbar = tqdm(total=len(trainloader) * FLG.validation_term,
                        desc='Epoch {:<3}-{:>3} train'.format(
                            epoch, epoch + FLG.validation_term - 1))
        for images, targets in trainloader:
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(images)
            loss = train_criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_report.update_true(targets)
            train_report.update_score(F.softmax(outputs, dim=1))

            summary.scalar('loss',
                           'train ' + fold_str,
                           global_step / len(trainloader),
                           loss.item(),
                           ytickmin=0,
                           ytickmax=1)

            pbar.update()
            global_step += 1

        if epoch % FLG.validation_term != 0:
            timer.toc()
            continue
        pbar.close()

        # valid()
        torch.set_grad_enabled(False)
        model.eval()
        valid_report.clear()
        pbar = tqdm(total=len(validloader),
                    desc='Epoch {:>3} valid'.format(epoch))
        for images, targets in validloader:
            true = targets
            npatchs = 1
            if len(images.shape) == 6:
                _, npatchs, c, x, y, z = images.shape
                images = images.view(-1, c, x, y, z)
                targets = torch.cat([targets
                                     for _ in range(npatchs)]).squeeze()
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            output = model(images)
            loss = valid_criterion(output, targets)

            valid_report.loss += loss.item()

            if npatchs == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)
            valid_report.update_true(true)
            valid_report.update_score(score)

            pbar.update()
        pbar.close()

        # report
        vloss = valid_report.loss / len(validloader)
        summary.scalar('accuracy',
                       'train ' + fold_str,
                       epoch,
                       train_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        summary.scalar('loss',
                       'valid ' + fold_str,
                       epoch,
                       vloss,
                       ytickmin=0,
                       ytickmax=0.8)
        summary.scalar('accuracy',
                       'valid ' + fold_str,
                       epoch,
                       valid_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        is_best = False
        if best_score['loss'] > vloss:
            best_score['loss'] = vloss
            best_score['epoch'] = epoch
            best_score['accuracy'] = valid_report.accuracy
            is_best = True

        print('Best Epoch {}: validation loss {} accuracy {}'.format(
            best_score['epoch'], best_score['loss'], best_score['accuracy']))

        # save
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()

        save_checkpoint(
            dict(epoch=epoch,
                 best_score=best_score,
                 state_dict=state_dict,
                 optimizer_state_dict=optimizer.state_dict()),
            FLG.checkpoint_root, FLG.running_fold, FLG.model, is_best)
        pbar = None
        timer.toc()
        print('Time elapse {}h {}m {}s'.format(*timer.total()))
Example #4
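TensorFlow training setup: creates the train op, initializes any uninitialized variables, registers loss and average-precision summaries on a Summary helper, and writes the graph before the training loop starts.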
        train_op = optimizer.minimize(model.losses["total"],
                                      global_step=global_step)

        # initialize variables
        uninitialized_variables = tfu_get_uninitialized_variables(session)
        if len(uninitialized_variables) > 0:
            session.run(tf.variables_initializer(uninitialized_variables))
        session.run(restore_ops)

        # initialize saver
        saver = tf.train.Saver(max_to_keep=100)

        # setup summaries
        summary_writer = tf.summary.FileWriter(run.base_path, session.graph)

        summary = Summary(session, summary_writer)
        summary.init_losses()
        summary.init_average_precisions(dataset.selected_labels)
        summary.init_mean_average_precision()
        summary.merge()

        # save graph.pbtxt
        tf.train.write_graph(session.graph_def, run.base_path, "graph.pbtxt")

        # setup some other stuff
        average_precision = AveragePrecision(matching_threshold=0.5)

        # initialize training loop
        epoch_value = session.run(epoch)
        global_step_value = session.run(global_step)
        logging_info("Training from epoch {}, step {}.".format(
Example #5
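Full Learner class (its constructor matches Example #2), including the data generator for fluid.layers.py_reader, the learn loop, remote-actor management, and metric logging to CSV and Summary.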
class Learner(object):
    def __init__(self, config):
        self.config = config
        self.sample_data_queue = queue.Queue(
            maxsize=config['sample_queue_max_size'])

        #=========== Create Agent ==========
        env = IntraBuildingEnv("config.ini")
        self._mansion_attr = env._mansion.attribute
        self._obs_dim = obs_dim(self._mansion_attr)
        self._act_dim = act_dim(self._mansion_attr)

        self.config['obs_shape'] = self._obs_dim
        self.config['act_dim'] = self._act_dim

        model = RLDispatcherModel(self._act_dim)
        algorithm = IMPALA(model, hyperparas=config)
        self.agent = ElevatorAgent(algorithm, config, self.learn_data_provider)

        self.cache_params = self.agent.get_params()
        self.params_lock = threading.Lock()
        self.params_updated = False
        self.cache_params_sent_cnt = 0
        self.total_params_sync = 0

        #========== Learner ==========
        self.lr, self.entropy_coeff = None, None
        self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
        self.entropy_coeff_scheduler = PiecewiseScheduler(
            config['entropy_coeff_scheduler'])

        self.total_loss_stat = WindowStat(100)
        self.pi_loss_stat = WindowStat(100)
        self.vf_loss_stat = WindowStat(100)
        self.entropy_stat = WindowStat(100)
        self.kl_stat = WindowStat(100)
        self.learn_time_stat = TimeStat(100)
        self.start_time = None

        self.learn_thread = threading.Thread(target=self.run_learn)
        self.learn_thread.setDaemon(True)
        self.learn_thread.start()

        #========== Remote Actor ===========
        self.remote_count = 0

        self.batch_buffer = []
        self.remote_metrics_queue = queue.Queue()
        self.sample_total_steps = 0

        self.remote_manager_thread = threading.Thread(
            target=self.run_remote_manager)
        self.remote_manager_thread.setDaemon(True)
        self.remote_manager_thread.start()

        self.csv_logger = CSVLogger(
            os.path.join(logger.get_dir(), 'result.csv'))

        from utils import Summary
        self.summary = Summary('./output')


    def learn_data_provider(self):
        """ Data generator for fluid.layers.py_reader
        """
        while True:
            sample_data = self.sample_data_queue.get()
            self.sample_total_steps += sample_data['obs'].shape[0]
            self.batch_buffer.append(sample_data)

            buffer_size = sum(
                [data['obs'].shape[0] for data in self.batch_buffer])
            if buffer_size >= self.config['train_batch_size']:
                batch = {}
                for key in self.batch_buffer[0].keys():
                    batch[key] = np.concatenate(
                        [data[key] for data in self.batch_buffer])
                self.batch_buffer = []

                obs_np = batch['obs'].astype('float32')
                actions_np = batch['actions'].astype('int64')
                behaviour_logits_np = batch['behaviour_logits'].astype(
                    'float32')
                rewards_np = batch['rewards'].astype('float32')
                dones_np = batch['dones'].astype('float32')

                self.lr = self.lr_scheduler.step()
                self.entropy_coeff = self.entropy_coeff_scheduler.step()

                yield [
                    obs_np, actions_np, behaviour_logits_np, rewards_np,
                    dones_np, self.lr, self.entropy_coeff
                ]

    def run_learn(self):
        """ Learn loop
        """
        while True:
            with self.learn_time_stat:
                total_loss, pi_loss, vf_loss, entropy, kl = self.agent.learn()

            self.params_updated = True

            self.total_loss_stat.add(total_loss)
            self.pi_loss_stat.add(pi_loss)
            self.vf_loss_stat.add(vf_loss)
            self.entropy_stat.add(entropy)
            self.kl_stat.add(kl)

    def run_remote_manager(self):
        """ Accept connection of new remote actor and start sampling of the remote actor.
        """
        remote_manager = RemoteManager(port=self.config['server_port'])
        logger.info('Waiting for remote actors connecting.')
        while True:
            remote_actor = remote_manager.get_remote()
            self.remote_count += 1
            logger.info('Remote actor count: {}'.format(self.remote_count))
            if self.start_time is None:
                self.start_time = time.time()

            remote_thread = threading.Thread(
                target=self.run_remote_sample, args=(remote_actor, ))
            remote_thread.setDaemon(True)
            remote_thread.start()

    def run_remote_sample(self, remote_actor):
        """ Sample data from remote actor and update parameters of remote actor.
        """
        cnt = 0
        remote_actor.set_params(self.cache_params)
        while True:
            batch = remote_actor.sample()
            if batch:
                self.sample_data_queue.put(batch)

            cnt += 1
            if cnt % self.config['get_remote_metrics_interval'] == 0:
                metrics = remote_actor.get_metrics()
                if metrics:
                    self.remote_metrics_queue.put(metrics)

            self.params_lock.acquire()

            if self.params_updated and self.cache_params_sent_cnt >= self.config[
                    'params_broadcast_interval']:
                self.params_updated = False
                self.cache_params = self.agent.get_params()
                self.cache_params_sent_cnt = 0
            self.cache_params_sent_cnt += 1
            self.total_params_sync += 1

            self.params_lock.release()

            remote_actor.set_params(self.cache_params)

    def log_metrics(self):
        """ Log metrics of learner and actors
        """
        if self.start_time is None:
            return

        metrics = []
        while True:
            try:
                metric = self.remote_metrics_queue.get_nowait()
                metrics.append(metric)
            except queue.Empty:
                break

        episode_rewards, episode_steps, episode_shaping_rewards, episode_deliver_rewards, episode_wrong_deliver_rewards = [], [], [], [], []
        for x in metrics:
            episode_rewards.extend(x['episode_rewards'])
            episode_steps.extend(x['episode_steps'])
            episode_shaping_rewards.extend(x['episode_shaping_rewards'])
            episode_deliver_rewards.extend(x['episode_deliver_rewards'])
            episode_wrong_deliver_rewards.extend(x['episode_wrong_deliver_rewards'])

        max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
                max_episode_steps, mean_episode_steps, min_episode_steps =\
                None, None, None, None, None, None
        mean_episode_shaping_rewards, mean_episode_deliver_rewards, mean_episode_wrong_deliver_rewards = \
                None, None, None
        if episode_rewards:
            mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
            max_episode_rewards = np.max(np.array(episode_rewards).flatten()) 
            min_episode_rewards = np.min(np.array(episode_rewards).flatten())

            mean_episode_steps = np.mean(np.array(episode_steps).flatten())
            max_episode_steps = np.max(np.array(episode_steps).flatten())
            min_episode_steps = np.min(np.array(episode_steps).flatten())
            
            mean_episode_shaping_rewards = np.mean(np.array(episode_shaping_rewards).flatten())
            mean_episode_deliver_rewards = np.mean(np.array(episode_deliver_rewards).flatten())
            mean_episode_wrong_deliver_rewards = np.mean(np.array(episode_wrong_deliver_rewards).flatten())

        #print(self.learn_time_stat.time_samples.items)
        
        metric = {
            'Sample steps': self.sample_total_steps,
            'max_episode_rewards': max_episode_rewards,
            'mean_episode_rewards': mean_episode_rewards,
            'min_episode_rewards': min_episode_rewards,
            'mean_episode_shaping_rewards': mean_episode_shaping_rewards,
            'mean_episode_deliver_rewards': mean_episode_deliver_rewards,
            'mean_episode_wrong_deliver_rewards': mean_episode_wrong_deliver_rewards,
            #'max_episode_steps': max_episode_steps,
            #'mean_episode_steps': mean_episode_steps,
            #'min_episode_steps': min_episode_steps,
            'sample_queue_size': self.sample_data_queue.qsize(),
            'total_params_sync': self.total_params_sync,
            'cache_params_sent_cnt': self.cache_params_sent_cnt,
            'total_loss': self.total_loss_stat.mean,
            'pi_loss': self.pi_loss_stat.mean,
            'vf_loss': self.vf_loss_stat.mean,
            'entropy': self.entropy_stat.mean,
            'kl': self.kl_stat.mean,
            'learn_time_s': self.learn_time_stat.mean,
            'elapsed_time_s': int(time.time() - self.start_time),
            'lr': self.lr,
            'entropy_coeff': self.entropy_coeff,
        }

        logger.info(metric)
        self.csv_logger.log_dict(metric)
        self.summary.log_dict(metric, self.sample_total_steps)

    def close(self):
        self.csv_logger.close()
Example #6
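Small script that loads NIfTI volumes with nibabel and sends them to Summary.image3d for visual inspection.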
import numpy as np
from skimage.exposure import rescale_intensity
import nibabel as nib
from utils import Summary

from tqdm import tqdm
from glob import glob

if __name__ == '__main__':
    s = Summary(port=39199, env='sample')
    fns = glob('data/ADNI/original/pp/*.nii')
    for fn in fns:
        sample = nib.load(fn)
        image = sample.get_data().astype(float)
        s.image3d('asdf', image)
Example #7
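Generates class activation maps per fold: hooks the last convolutional layer, reads the fc weights, and renders CAMs for selected validation samples via Summary.cam3d.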
def test_cam():
    FLAGS.noadapt = True
    summary = Summary(port=39199, env=FLAGS.visdom_env)
    with open(FLAGS.model + '_stat.pkl', 'rb') as f:
        stat = pickle.load(f)

    class Feature(object):
        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    # TODO: cropped cam?
    transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])

    model = create_model()
    model = model.cuda(FLAGS.devices[0])
    for running_k in range(FLAGS.kfold):
        load_checkpoint(model,
                        FLAGS.checkpoint_root,
                        running_k,
                        FLAGS.model,
                        epoch=None,
                        is_best=True)
        model.eval()

        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        _ = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()

        _, validset, _ =\
            make_kfold_dataset(FLAGS.kfold, running_k,
                               np.load(pjoin(FLAGS.root, 'sids.npy')),
                               np.load(pjoin(FLAGS.root, 'diagnosis.npz')),
                               FLAGS.labels, [FLAGS.dataset_root],
                               validset_loader=ScaledLoader(
                                   (112, 144, 112), False),
                               valid_transform=transformer)

        i, p = stat[running_k]['pos_best']['index'], stat[running_k][
            'pos_best']['p']
        i = 8  # hard-coded sample index overrides the stat-selected one
        image, target = validset[i]
        _ = model(Variable(image.view(1, *image.shape).cuda(FLAGS.devices[0])))

        name = 'k' + str(running_k)

        summary.cam3d(name + FLAGS.labels[target] + '_pos_best_' + str(p),
                      image,
                      feature.blob,
                      fc_weights,
                      target,
                      num_images=10)

        i, p = stat[running_k]['neg_best']['index'], stat[running_k][
            'neg_best']['p']
        i = 45  # hard-coded sample index overrides the stat-selected one
        image, target = validset[i]
        _ = model(Variable(image.view(1, *image.shape).cuda(FLAGS.devices[0])))

        summary.cam3d(name + FLAGS.labels[target] + '_neg_best_' + str(p),
                      image,
                      feature.blob,
                      fc_weights,
                      target,
                      num_images=10)
Example #8
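Per-fold evaluation with optional CAM rendering bucketed by softmax score; prints classification reports per fold and overall, then re-saves the fold reports.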
def test(FLG):
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    overall_report = ScoreReport()
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'))

    with open(FLG.model + '_stat.pkl', 'rb') as f:
        stat = pickle.load(f)
    summary = Summary(port=10001, env=str(FLG.model) + 'CAM')

    class Feature(object):
        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels), name=FLG.model)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels), FLG.model)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels), FLG.model)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels), FLG.model)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels), FLG.model)
    else:
        raise NotImplementedError(FLG.model)
    model.to(device)

    ad_h = []
    nl_h = []
    adcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    nlcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    sb = [9.996e-01, 6.3e-01, 1.001e-01]
    for running_fold in range(FLG.fold):
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transform_presets(FLG.augmentation))
        validloader = DataLoader(validset, pin_memory=True)

        epoch, _ = load_checkpoint(model, FLG.checkpoint_root, running_fold,
                                   FLG.model, None, True)
        model.eval()
        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        _ = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()

        transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])
        im, _ = original_load(validblock, target_dict, transformer, device)

        for image, target in validloader:
            true = target
            npatches = 1
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

            output = model(image)

            if npatches == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)

            report[running_fold].update_true(true)
            report[running_fold].update_score(score)

            overall_report.update_true(true)
            overall_report.update_score(score)

            print(target)
            if FLG.cam:
                s = 0
                cams = []
                if target[0] == 0:
                    s = score[0][0]
                    #s = s.cpu().numpy()[()]
                    cams = adcams
                else:
                    s = score[0][1]
                    #s = s.cpu().numpy()[()]
                    cams = nlcams
                if s > sb[0]:
                    cams[0] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[0],
                                            s,
                                            num_images=5)
                elif s > sb[1]:
                    cams[1] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[1],
                                            s,
                                            num_images=5)
                elif s > sb[2]:
                    cams[2] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[2],
                                            s,
                                            num_images=5)
                else:
                    cams[3] = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            cams[3],
                                            s,
                                            num_images=5)
                #ad_h += [s]
                #nl_h += [sn]

        print('At {}'.format(epoch))
        print(
            metrics.classification_report(report[running_fold].y_true,
                                          report[running_fold].y_pred,
                                          target_names=FLG.labels,
                                          digits=4))
        print('accuracy {}'.format(report[running_fold].accuracy))

    #print(np.histogram(ad_h))
    #print(np.histogram(nl_h))

    print('over all')
    print(
        metrics.classification_report(overall_report.y_true,
                                      overall_report.y_pred,
                                      target_names=FLG.labels,
                                      digits=4))
    print('accuracy {}'.format(overall_report.accuracy))

    with open(FLG.model + '_stat.pkl', 'wb') as f:
        pickle.dump(report, f, pickle.HIGHEST_PROTOCOL)
Example #9
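CAM generation for a Plane model: buckets AD and NL validation samples by softmax score and accumulates CAM volumes through Summary.cam3d.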
def test_cam():
    with open(FLG.model + '_stat.pkl', 'rb') as f:
        stat = pickle.load(f)
    summary = Summary(port=10001, env=str(FLG.model) + 'CAM')

    class Feature(object):
        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    # TODO: create model
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'))

    model = Plane(len(FLG.labels), name=FLG.model)
    model.to(device)

    transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])

    def original_load(validblock):
        originalset = ADNIDataset(FLG.labels,
                                  pjoin(FLG.data_root, 'spm_normalized'),
                                  validblock,
                                  target_dict,
                                  transform=transformer)
        originloader = DataLoader(originalset, pin_memory=True)
        for image, target in originloader:
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)
            break
        return image, target

    hadcams = np.zeros((3, 112, 144, 112), dtype="f8")
    madcams = np.zeros((3, 112, 144, 112), dtype="f8")
    sadcams = np.zeros((3, 112, 144, 112), dtype="f8")
    zadcams = np.zeros((3, 112, 144, 112), dtype="f8")

    nlcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    sb = [4.34444371e-16, 1.67179015e-18, 4.08813312e-23]
    #im, _ = original_load(validblock)
    for running_fold in range(FLG.fold):
        # validset
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transformer)
        validloader = DataLoader(validset, pin_memory=True)

        load_checkpoint(model,
                        FLG.checkpoint_root,
                        running_fold,
                        FLG.model,
                        epoch=None,
                        is_best=True)
        model.eval()
        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        _ = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()

        im, _ = original_load(validblock)
        ad_s = []
        for image, target in validloader:
            true = target
            npatches = 1
            if len(image.shape) == 6:
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
                target = torch.stack([target
                                      for _ in range(npatches)]).squeeze()
            image = image.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

            #_ = model(image.view(*image.shape))
            output = model(image)

            if npatches == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)

            sa = score[0][1]
            #name = 'k'+str(running_fold)
            sa = sa.cpu().numpy()[()]
            print(score, score.shape)
            if true == torch.tensor([0]):
                if sa > sb[0]:
                    hadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            hadcams,
                                            sa,
                                            num_images=5)
                elif sa > sb[1]:
                    madcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            madcams,
                                            sa,
                                            num_images=5)
                elif sa > sb[2]:
                    sadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            sadcams,
                                            sa,
                                            num_images=5)
                else:
                    zadcams = summary.cam3d(FLG.labels[target],
                                            im,
                                            feature.blob,
                                            fc_weights,
                                            target,
                                            zadcams,
                                            sa,
                                            num_images=5)
            else:
                if sa > sb[0]:
                    nlcams[0] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[0],
                                              sa,
                                              num_images=5)
                elif sa > sb[1]:
                    nlcams[1] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[1],
                                              sa,
                                              num_images=5)
                elif sa > sb[2]:
                    nlcams[2] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[2],
                                              sa,
                                              num_images=5)
                else:
                    nlcams[3] = summary.cam3d(FLG.labels[target],
                                              im,
                                              feature.blob,
                                              fc_weights,
                                              target,
                                              nlcams[3],
                                              sa,
                                              num_images=5)
            ad_s += [sa]
        print('histogram', np.histogram(ad_s))