if __name__ == '__main__': parser = Flags() parser.set_arguments() FG = parser.parse_args() c_code, axis, z_dim = FG.c_code, FG.axis, FG.z_dim device = torch.device(FG.devices[0]) torch.cuda.set_device(FG.devices[0]) nets = [] for i in range(FG.fold): parser.configure('cur_fold', i) parser.configure('ckpt_dir') FG = parser.load() net = Baseline(FG.ckpt_dir, len(FG.labels)) net.to(device) net.load(epoch=None, optimizer=None, is_best=True) net.eval() nets += [net] #G = Generator(FG) G = torch.nn.DataParallel(Generator(z_dim, c_code, axis)) # state_dict = torch.load(os.path.join('BiGAN-info-c4-f', 'G.pth'), 'cpu') state_dict = torch.load(os.path.join('157-G8', 'G.pth'), 'cpu') G.load_state_dict(state_dict) G.to(device) G.eval() if axis == 1: ns = ((160, 160), (112, 112)) elif axis == 0:
def run_fold(parser, vis): devices = parser.args.devices parser.args.ckpt_dir = os.path.join('checkpoint', parser.args.model, 'f' + str(parser.args.cur_fold)) FG = parser.load() FG.devices = devices print(FG) torch.cuda.set_device(FG.devices[0]) device = torch.device(FG.devices[0]) net = Baseline(FG.ckpt_dir, len(FG.labels)) performances = net.load(epoch=None, is_best=True) net = net.to(device) trainloader, testloader = get_dataloader(k=FG.fold, cur_fold=FG.cur_fold, modality=FG.modality, axis=FG.axis, labels=FG.labels, batch_size=FG.batch_size) evaluator = create_supervised_evaluator( net, device=device, non_blocking=True, prepare_batch=process_ninecrop_batch, metrics={ 'sensitivity': Recall(False, mean_over_ninecrop), 'precision': Precision(False, mean_over_ninecrop), 'specificity': Specificity(False, mean_over_ninecrop) }) class Tracker(object): def __init__(self): self.data = [] outputs = Tracker() targets = Tracker() @evaluator.on(Events.ITERATION_COMPLETED) def transform_ninecrop_output(engine): output, target = engine.state.output if output.size(0) != target.size(0): n = target.size(0) npatches = output.size(0) // n output = output.view(n, npatches, *output.shape[1:]) output = torch.mean(output, dim=1) outputs.data += [output] targets.data += [target] evaluator.run(testloader) string = 'Fold {}'.format(FG.cur_fold) + '<br>' string += 'Epoch {}'.format(performances.pop('epoch')) + '<br>' for k in sorted(performances.keys()): string += k + ': ' + '{:.4f}'.format(performances[k]) string += '<br>' string += 'pre : ' + str(evaluator.state.metrics['precision']) + '<br>' string += 'sen : ' + str(evaluator.state.metrics['sensitivity']) + '<br>' string += 'spe : ' + str(evaluator.state.metrics['specificity']) + '<br>' vis.text(string, win=FG.model + '_result_fold{}'.format(FG.cur_fold)) del net return outputs.data, targets.data