Beispiel #1
0
def build_model(cfg, writer):
    """Build a FusionNet and move it to the best available device.

    Args:
        cfg: config object; reads cfg.TRAIN.input_nc / output_nc / ngf for
            model construction and cfg.TRAIN.batch_size for the multi-GPU
            divisibility check.
        writer: summary writer (unused here; kept for interface
            compatibility — the graph-tracing code that used it is gone).

    Returns:
        The model on the chosen device, wrapped in ``nn.DataParallel`` when
        more than one GPU is visible.

    Raises:
        AttributeError: if the batch size is not evenly divisible by the GPU
            count (``nn.DataParallel`` splits the batch across GPUs).
    """
    print('Building model on ', end='', flush=True)
    t1 = time.time()
    # Fall back to CPU when CUDA is unavailable instead of crashing on a
    # hard-coded 'cuda:0' (matches the device selection used elsewhere in
    # this file).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = FusionNet(input_nc=cfg.TRAIN.input_nc,
                      output_nc=cfg.TRAIN.output_nc,
                      ngf=cfg.TRAIN.ngf).to(device)

    cuda_count = torch.cuda.device_count()
    if cuda_count > 1:
        if cfg.TRAIN.batch_size % cuda_count == 0:
            print('%d GPUs ... ' % cuda_count, end='', flush=True)
            model = nn.DataParallel(model)
        else:
            raise AttributeError('Batch size (%d) cannot be equally divided by GPU number (%d)' % (cfg.TRAIN.batch_size, cuda_count))
    else:
        print('a single GPU ... ', end='', flush=True)
    print('Done (time: %.2fs)' % (time.time() - t1))
    return model
Beispiel #2
0
def train():
    """Train FusionNet on SQuAD-style data, evaluating on dev each epoch.

    Reads configuration from the module-level ``args``. Side effects:
    creates ``train_model/`` and ``result/`` directories, writes answer
    files, pickled score histories, and the best model checkpoint to
    ``args.model_dir``. In ``--eval`` mode, loads a checkpoint, evaluates,
    and exits the process.
    """
    if not os.path.exists('train_model/'):
        os.makedirs('train_model/')
    if not os.path.exists('result/'):
        os.makedirs('result/')

    train_data, dev_data, word2id, char2id, opts = load_data(vars(args))
    model = FusionNet(opts)

    if args.use_cuda:
        model = model.cuda()

    dev_batches = get_batches(dev_data, args.batch_size)

    if args.eval:
        # Evaluation-only mode: load checkpoint, score the dev set, exit.
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        model.Evaluate(dev_batches,
                       args.data_path + 'dev_eval.json',
                       answer_file='result/' + args.model_dir.split('/')[-1] +
                       '.answers')
        exit()

    train_batches = get_batches(train_data, args.batch_size)
    total_size = len(train_batches)

    # Materialize as a list: `filter(...)` is a one-shot iterator, and the
    # Adamax constructor consumes it — a bare filter would leave
    # clip_grad_norm_ below iterating an exhausted iterator, silently
    # disabling gradient clipping.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)

    lrate = args.lrate
    best_score = 0.0
    f1_scores = []
    em_scores = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        for i, train_batch in enumerate(train_batches):
            with torch.enable_grad():

                loss = model(train_batch)
                model.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters,
                                               opts['grad_clipping'])
                optimizer.step()
                # NOTE(review): called after every optimizer step —
                # presumably restores fixed (non-trainable) embedding rows
                # rather than re-initializing weights; confirm against
                # FusionNet.reset_parameters.
                model.reset_parameters()

            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, loss, lrate, best_score))

        # End-of-epoch evaluation on the dev set.
        with torch.no_grad():
            model.eval()
            exact_match_score, F1 = model.Evaluate(
                dev_batches,
                args.data_path + 'dev_eval.json',
                answer_file='result/' + args.model_dir.split('/')[-1] +
                '.answers')
            f1_scores.append(F1)
            em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        # Keep only the best-F1 checkpoint.
        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)
        # Step-decay the learning rate every `decay_period` epochs.
        if epoch > 0 and epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
Beispiel #3
0
def main(args):
    """Train a segmentation model and save a checkpoint.

    Args:
        args: parsed CLI namespace; reads dataset/model selection, training
            hyperparameters, device/GPU ids, and logging/save paths.

    Side effects: writes tensorboard logs via ``logger.Logger`` and saves a
    checkpoint dict to ``args.save_dir``.
    """
    # tensorboard
    logger_tb = logger.Logger(log_dir=args.experiment_name)

    # get dataset
    if args.dataset == "nuclei":
        train_dataset = NucleiDataset(args.train_data, 'train', args.transform,
                                      args.target_channels)
    elif args.dataset == "hpa":
        train_dataset = HPADataset(args.train_data, 'train', args.transform,
                                   args.max_mean, args.target_channels)
    elif args.dataset == "hpa_single":
        train_dataset = HPASingleDataset(args.train_data, 'train',
                                         args.transform)
    else:
        train_dataset = NeuroDataset(args.train_data, 'train', args.transform)

    # create dataloader
    train_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }
    train_dataloader = DataLoader(train_dataset, **train_params)

    # device
    device = torch.device(args.device)

    # model
    if args.model == "fusion":
        model = FusionNet(args, train_dataset.dim)
    elif args.model == "dilation":
        model = DilationCNN(train_dataset.dim)
    elif args.model == "unet":
        model = UNet(args.num_kernel, args.kernel_size, train_dataset.dim,
                     train_dataset.target_dim)
    else:
        # Fail fast instead of hitting a confusing NameError on `model` below.
        raise ValueError("unknown model: %s" % args.model)

    if args.device == "cuda":
        # parse gpu_ids for data parallel
        if ',' in args.gpu_ids:
            gpu_ids = [int(ids) for ids in args.gpu_ids.split(',')]
        else:
            gpu_ids = int(args.gpu_ids)

        # parallelize computation
        if type(gpu_ids) is not int:
            model = nn.DataParallel(model, gpu_ids)
    model.to(device)

    # optimizer
    parameters = model.parameters()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(parameters, args.lr)
    else:
        optimizer = torch.optim.SGD(parameters, args.lr)

    # loss
    loss_function = dice_loss

    count = 0  # tensorboard step counter, advanced once per epoch
    # train model
    for epoch in range(args.epoch):
        model.train()

        with tqdm.tqdm(total=len(train_dataloader.dataset),
                       unit=f"epoch {epoch} itr") as progress_bar:
            total_loss = []
            total_iou = []
            total_precision = []
            for i, (x_train, y_train) in enumerate(train_dataloader):

                with torch.set_grad_enabled(True):

                    # send data and label to device
                    x = torch.Tensor(x_train.float()).to(device)
                    y = torch.Tensor(y_train.float()).to(device)

                    # predict segmentation
                    pred = model.forward(x)

                    # calculate loss
                    loss = loss_function(pred, y)
                    total_loss.append(loss.item())

                    # calculate IoU precision
                    predictions = pred.clone().squeeze().detach().cpu().numpy()
                    gt = y.clone().squeeze().detach().cpu().numpy()
                    ious = [
                        metrics.get_ious(p, g, 0.5)
                        for p, g in zip(predictions, gt)
                    ]
                    total_iou.append(np.mean(ious))

                    # back prop
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # log running-average loss and iou
                avg_loss = np.mean(total_loss)
                avg_iou = np.mean(total_iou)

                logger_tb.update_value('train loss', avg_loss, count)
                logger_tb.update_value('train iou', avg_iou, count)

                # display segmentation on tensorboard (first batch only)
                if i == 0:
                    original = x_train[0].squeeze()
                    truth = y_train[0].squeeze()
                    seg = pred[0].cpu().squeeze().detach().numpy()

                    # TODO display segmentations based on number of output
                    logger_tb.update_image("truth", truth, count)
                    logger_tb.update_image("segmentation", seg, count)
                    logger_tb.update_image("original", original, count)

                    count += 1

                # Advance the bar every batch — it was previously nested
                # under `if i == 0`, so a bar sized to the whole dataset
                # only ever moved by one batch per epoch.
                progress_bar.update(len(x))

    # save model
    ckpt_dict = {
        'model_name': model.__class__.__name__,
        'model_args': model.args_dict(),
        'model_state': model.to('cpu').state_dict()
    }
    experiment_name = f"{model.__class__.__name__}_{args.dataset}_{train_dataset.target_dim}c"
    if args.dataset == "HPA":
        experiment_name += f"_{args.max_mean}"
    experiment_name += f"_{args.num_kernel}"
    ckpt_path = os.path.join(args.save_dir, f"{experiment_name}.pth")
    torch.save(ckpt_dict, ckpt_path)
Beispiel #4
0
        # Resolve test-result and raw-data directories from the configured track.
        if cfg.TRAIN.track == 'simple':
            test_path = '../test_result/simple'
            base_path = '../../data/simple/test_raw'
        else:
            test_path = '../../data/complex/' + cfg.TEST.crop_way + '/test_result'
            base_path = '../../data/complex/' + cfg.TEST.crop_way + '/test_raw'
        model_name = cfg.TEST.model_name
        save_path = os.path.join(test_path, model_name, 'result')
        model_path = cfg.TRAIN.save_path
        model_path = os.path.join(model_path, model_name)

        if not os.path.exists(save_path):
            os.makedirs(save_path)
        
        # Binarization threshold for predictions (name kept as in config).
        thresd = cfg.TEST.thresd
        model = FusionNet(input_nc=cfg.TRAIN.input_nc,output_nc=cfg.TRAIN.output_nc,ngf=cfg.TRAIN.ngf)
        ckpt = 'model.ckpt'
        ckpt_path = os.path.join(model_path, ckpt)
        checkpoint = torch.load(ckpt_path)

        # The checkpoint was saved from an nn.DataParallel-wrapped model, so
        # every key carries a 'module.' prefix; strip the first 7 characters
        # to load into a bare (unwrapped) model.
        new_state_dict = OrderedDict()
        state_dict = checkpoint['model_weights']
        for k, v in state_dict.items():
            name = k[7:] # remove module.
            # name = k
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
        # Prefer GPU when available, otherwise run inference on CPU.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
Beispiel #5
0
def main(args):
    """Train a segmentation model on nuclei data, validating each epoch.

    Args:
        args: parsed CLI namespace; reads model selection, training
            hyperparameters, device/GPU ids, and logging/save paths.

    Side effects: writes tensorboard logs via ``logger.Logger`` and saves a
    checkpoint dict to ``args.save_dir``.
    """
    # tensorboard
    logger_tb = logger.Logger(log_dir=args.experiment_name)

    #augmenter = get_augmenter(args)

    # train dataloader and val dataset
    train_dataset = NucleiDataset(args.train_data, 'train', transform=True)
    val_dataset = NucleiDataset(args.val_data, 'val', transform=True)

    train_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }

    train_dataloader = DataLoader(train_dataset, **train_params)

    # device
    device = torch.device(args.device)

    # model
    if args.model == "fusion":
        model = FusionNet(args, train_dataset.dim)
    elif args.model == "dilation":
        model = DilationCNN(train_dataset.dim)
    elif args.model == "unet":
        model = UNet(args.num_kernel, args.kernel_size, train_dataset.dim)
    else:
        # Fail fast instead of hitting a confusing NameError on `model` below.
        raise ValueError("unknown model: %s" % args.model)

    if args.device == "cuda":
        # parse gpu_ids for data parallel
        if ',' in args.gpu_ids:
            gpu_ids = [int(ids) for ids in args.gpu_ids.split(',')]
        else:
            gpu_ids = int(args.gpu_ids)

        # parallelize computation
        if type(gpu_ids) is not int:
            model = nn.DataParallel(model, gpu_ids)
    model.to(device)

    # optimizer
    parameters = model.parameters()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(parameters, args.lr)
    else:
        optimizer = torch.optim.SGD(parameters, args.lr)

    # loss
    loss_function = dice_loss

    # train model
    for epoch in range(args.epoch):
        model.train()

        with tqdm.tqdm(total=len(train_dataloader.dataset),
                       unit=f"epoch {epoch} itr") as progress_bar:
            total_loss = []
            total_iou = []
            total_precision = []
            for i, (x_train, y_train) in enumerate(train_dataloader):

                with torch.set_grad_enabled(True):

                    # send data and label to device
                    x = torch.Tensor(x_train.float()).to(device)
                    y = torch.Tensor(y_train.float()).to(device)

                    # predict segmentation
                    pred = model.forward(x)

                    # calculate loss
                    loss = loss_function(pred, y)
                    total_loss.append(loss.item())

                    # calculate IoU precision
                    predictions = pred.clone().squeeze().detach().cpu().numpy()
                    gt = y.clone().squeeze().detach().cpu().numpy()
                    ious = [
                        metrics.get_ious(p, g, 0.5)
                        for p, g in zip(predictions, gt)
                    ]
                    total_iou.append(np.mean(ious))

                    # back prop
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # log running-average loss and iou
                avg_loss = np.mean(total_loss)
                avg_iou = np.mean(total_iou)

                logger_tb.update_value('train loss', avg_loss, epoch)
                logger_tb.update_value('train iou', avg_iou, epoch)

                progress_bar.update(len(x))

        # validation
        model.eval()
        # Accumulators must live OUTSIDE the per-sample loop; they were
        # previously re-initialized every iteration, so the epoch-level
        # means below reflected only the last validation sample.
        total_precision = []
        total_iou = []
        total_loss = []
        for idx in range(len(val_dataset)):
            x_val, y_val, mask_val = val_dataset.__getitem__(idx)

            with torch.no_grad():

                # send data and label to device (add batch dim to the input)
                x_val = np.expand_dims(x_val, axis=0)
                x = torch.Tensor(torch.tensor(x_val).float()).to(device)
                y = torch.Tensor(torch.tensor(y_val).float()).to(device)

                # predict segmentation
                pred = model.forward(x)

                # calculate loss
                loss = loss_function(pred, y)
                total_loss.append(loss.item())

                # calculate IoU
                prediction = pred.clone().squeeze().detach().cpu().numpy()
                gt = y.clone().squeeze().detach().cpu().numpy()
                iou = metrics.get_ious(prediction, gt, 0.5)
                total_iou.append(iou)

                # calculate precision
                precision = metrics.compute_precision(prediction, mask_val,
                                                      0.5)
                total_precision.append(precision)

                # display segmentation on tensorboard
                if idx == 1:
                    original = x_val
                    truth = np.expand_dims(y_val, axis=0)
                    seg = pred.cpu().squeeze().detach().numpy()
                    seg = np.expand_dims(seg, axis=0)

                    logger_tb.update_image("original", original, 0)
                    logger_tb.update_image("ground truth", truth, 0)
                    logger_tb.update_image("segmentation", seg, epoch)

        # log metrics averaged over the whole validation set
        logger_tb.update_value('val loss', np.mean(total_loss), epoch)
        logger_tb.update_value('val iou', np.mean(total_iou), epoch)
        logger_tb.update_value('val precision', np.mean(total_precision),
                               epoch)

    # save model
    ckpt_dict = {
        'model_name': model.__class__.__name__,
        'model_args': model.args_dict(),
        'model_state': model.to('cpu').state_dict()
    }
    ckpt_path = os.path.join(args.save_dir, f"{model.__class__.__name__}.pth")
    torch.save(ckpt_dict, ckpt_path)