# Example 1
def main():
    """Iteration-driven training loop for a segmentation model.

    Builds the model, optimizer and loss by name from the module-level
    ``args`` config, optionally resumes from a checkpoint, then trains for a
    fixed number of iterations with periodic checkpoint saves and validation.

    Depends on module-level globals: ``args``, ``models``, ``criterions``,
    ``datasets``, ``ckpts``, ``CycleSampler``, ``init_fn``, ``AverageMeter``,
    ``adjust_learning_rate`` and ``validate_softmax``.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Network = getattr(models, args.net)  # model class resolved by name
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).to(device)
    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    # criterion is a loss *function*, called as criterion(output, target, ...)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Restore the iteration counter alongside model/optimizer state.
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = ("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'

    msg += '\n' + str(args)
    logging.info(msg)

    Dataset = getattr(datasets, args.dataset)  # dataset class resolved by name

    if args.prefix_path:
        args.train_data_dir = os.path.join(args.prefix_path,
                                           args.train_data_dir)
    train_list = os.path.join(args.train_data_dir, args.train_list)
    train_set = Dataset(train_list,
                        root=args.train_data_dir,
                        for_train=True,
                        transforms=args.train_transforms)

    # Total training length in iterations; args.num_iters overrides the
    # epoch-derived count when set. Subtract start_iter when resuming.
    num_iters = args.num_iters or (len(train_set) *
                                   args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    # CycleSampler presumably cycles through the dataset so the loader yields
    # num_iters batches in the single for-loop below — TODO confirm.
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.train_data_dir, args.valid_list)
        valid_set = Dataset(valid_list,
                            root=args.train_data_dir,
                            for_train=False,
                            transforms=args.test_transforms)

        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=args.workers,
                                  pin_memory=True)

    start = time.time()

    enum_batches = len(train_set) / float(
        args.batch_size)  # nums_batch per epoch
    # Rescale the schedule keys from epoch units to iteration units.
    args.schedule = {
        int(k * enum_batches): v
        for k, v in args.schedule.items()
    }  # 17100
    # args.save_freq = int(enum_batches * args.save_freq)
    # args.valid_freq = int(enum_batches * args.valid_freq)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, data in enumerate(train_loader, args.start_iter):

        # Epoch bookkeeping derived from the running iteration counter.
        elapsed_bsize = int(i / enum_batches) + 1
        epoch = int((i + 1) / enum_batches)
        setproctitle.setproctitle("Epoch:{}/{}".format(elapsed_bsize,
                                                       args.num_epochs))

        adjust_learning_rate(optimizer, epoch, args.num_epochs,
                             args.opt_params.lr)

        # data = [t.cuda(non_blocking=True) for t in data]
        data = [t.to(device) for t in data]
        x, target = data[:2]

        output = model(x)
        if not args.weight_type:  # compatible for the old version
            args.weight_type = 'square'

        # loss = criterion(output, target, args.eps,args.weight_type)
        # loss = criterion(output, target,args.alpha,args.gamma) # for focal loss
        # NOTE(review): args.kwargs is unpacked positionally despite its name.
        loss = criterion(output, target, *args.kwargs)

        # measure accuracy and record loss
        losses.update(loss.item(), target.numel())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Checkpoint every save_freq epochs, plus whenever (i + 1) is a
        # multiple of the iteration counts for epochs num_epochs-1..-4
        # (apparently intended to capture the last few epochs).
        if (i + 1) % int(enum_batches * args.save_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            file_name = os.path.join(ckpts, 'model_epoch_{}.pth'.format(epoch))
            torch.save(
                {
                    'iter': i + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        # validation
        # NOTE(review): valid_loader/valid_set only exist when args.valid_list
        # is set; this branch would raise NameError otherwise — confirm the
        # configs used always set valid_list (or guard this branch).
        if (i + 1) % int(enum_batches * args.valid_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            logging.info('-' * 50)
            msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
                                                     'validation')
            logging.info(msg)
            with torch.no_grad():
                validate_softmax(valid_loader,
                                 model,
                                 cfg=args.cfg,
                                 savepath='',
                                 names=valid_set.names,
                                 scoring=True,
                                 verbose=False,
                                 use_TTA=False,
                                 snapshot=False,
                                 postprocess=False,
                                 cpu_only=False)

        # losses is reset every iteration, so losses.avg here is effectively
        # the current iteration's loss, not a running average.
        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.7f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)

        logging.info(msg)
        losses.reset()

    # Final checkpoint after the loop exhausts all scheduled iterations.
    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.pth')
    torch.save(
        {
            'iter': i,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
# Example 2
def main():
    """Run softmax inference on the BraTS validation split.

    Resolves the model by name from the module-level ``config``, loads a
    checkpoint if present (otherwise proceeds with random weights and warns),
    writes segmentations/visualizations under the output directories, and
    prints the average per-case inference time in minutes.

    Depends on module-level globals: ``config``, ``models``, ``BraTS`` and
    ``validate_softmax``.
    """
    # Seed every RNG source so inference (e.g. any stochastic transforms)
    # is reproducible.
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    # Model class resolved by name from the config, wrapped for multi-GPU.
    #model = generate_model(config)
    model = getattr(models, config.model_name)()
    # model = getattr(models, config.model_name)(c=4,n=32,channels=128, groups=16,norm='sync_bn', num_classes=4,output_func='softmax')
    model = torch.nn.DataParallel(model).cuda()

    # Hoisted: the script directory is used for the checkpoint path and both
    # output paths below.
    base_dir = os.path.abspath(os.path.dirname(__file__))
    load_file = os.path.join(base_dir, 'checkpoint',
                             config.experiment + config.test_date,
                             config.test_file)

    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        config.start_epoch = checkpoint['epoch']
        print('Successfully load checkpoint {}'.format(
            os.path.join(config.experiment + config.test_date,
                         config.test_file)))
    else:
        # Best-effort: inference continues with the freshly built weights.
        print('There is no resume file to load!')

    valid_list = os.path.join(config.root, config.valid_dir, config.valid_file)
    valid_root = os.path.join(config.root, config.valid_dir)
    valid_set = BraTS(valid_list, valid_root, mode='test')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=config.num_workers,
                              pin_memory=True)

    submission = os.path.join(base_dir, config.output_dir, config.submission,
                              config.experiment + config.test_date)
    visual = os.path.join(base_dir, config.output_dir, config.visual,
                          config.experiment + config.test_date)
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(submission, exist_ok=True)
    os.makedirs(visual, exist_ok=True)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         scoring=False,
                         use_TTA=config.use_TTA,
                         save_format=config.save_format,
                         snapshot=True,
                         postprocess=True)

    # Average per-case inference time, reported in minutes.
    full_test_time = (time.time() - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('{:.2f} minutes!'.format(average_time))
# Example 3
def main():
    """Run softmax inference with a checkpoint that must exist.

    Selects the data root and scoring flag from ``args.mode`` (0 = train
    split with scoring, 1 = validation split, 2 = test split), builds the
    dataset/model by name from the module-level ``args``, and delegates to
    ``validate_softmax``.
    """
    # Pin visible GPUs and seed every RNG source.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, random.seed,
                     np.random.seed):
        seed_rng(args.seed)

    # Resolve the network class by name and wrap it for multi-GPU use.
    net_cls = getattr(models, args.net)
    model = torch.nn.DataParallel(net_cls(**args.net_params)).cuda()

    # A checkpoint is mandatory for inference.
    print(args.resume)
    assert os.path.isfile(args.resume), "no checkpoint found at {}".format(
        args.resume)
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_iter = checkpoint['iter']
    model.load_state_dict(checkpoint['state_dict'])
    msg = ("=> loaded checkpoint '{}' (iter {})".format(
        args.resume, checkpoint['iter']))
    msg += '\n' + str(args)
    logging.info(msg)

    # Pick the data root and whether ground truth is available for scoring.
    if args.mode == 0:
        root_path, is_scoring = args.train_data_dir, True
    elif args.mode == 1:
        root_path, is_scoring = args.valid_data_dir, False
    elif args.mode == 2:
        root_path, is_scoring = args.test_data_dir, False
    else:
        raise ValueError

    dataset_cls = getattr(datasets, args.dataset)
    eval_set = dataset_cls(os.path.join(root_path, args.test_list),
                           root=root_path,
                           for_train=False,
                           transforms=args.test_transforms)

    eval_loader = DataLoader(eval_set,
                             batch_size=1,
                             shuffle=False,
                             collate_fn=eval_set.collate,
                             num_workers=10,
                             pin_memory=True)

    # Prepare output directories only when results should be written.
    if args.is_out:
        out_dir = './output/{}'.format(args.cfg)
        for sub_dir in ('submission', 'snapshot'):
            os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)
    else:
        out_dir = ''

    logging.info('-' * 50)
    logging.info(msg)

    with torch.no_grad():
        validate_softmax(eval_loader,
                         model,
                         cfg=args.cfg,
                         savepath=out_dir,
                         save_format=args.save_format,
                         names=eval_set.names,
                         scoring=is_scoring,
                         verbose=args.verbose,
                         use_TTA=args.use_TTA,
                         snapshot=args.snapshot,
                         postprocess=args.postprocess,
                         cpu_only=False)
# Example 4
def main():
    """Run TransBTS softmax inference on the BraTS validation split.

    Builds the TransBTS model, loads a checkpoint if present (otherwise
    proceeds with random weights and warns), writes segmentations and
    visualizations, and prints the average per-case inference time in
    minutes.

    Depends on module-level globals: ``args``, ``TransBTS``, ``BraTS`` and
    ``validate_softmax``.
    """
    # Seed every RNG source so inference is reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # TransBTS returns a pair; only the model object is needed here.
    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")

    model = torch.nn.DataParallel(model).cuda()

    # Hoisted: the script directory is used for the checkpoint path and both
    # output paths below.
    base_dir = os.path.abspath(os.path.dirname(__file__))
    load_file = os.path.join(base_dir, 'checkpoint',
                             args.experiment + args.test_date, args.test_file)

    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        print('Successfully load checkpoint {}'.format(
            os.path.join(args.experiment + args.test_date, args.test_file)))
    else:
        # Best-effort: inference continues with the freshly built weights.
        print('There is no resume file to load!')

    valid_list = os.path.join(args.root, args.valid_dir, args.valid_file)
    valid_root = os.path.join(args.root, args.valid_dir)
    valid_set = BraTS(valid_list, valid_root, mode='test')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    submission = os.path.join(base_dir, args.output_dir, args.submission,
                              args.experiment + args.test_date)
    visual = os.path.join(base_dir, args.output_dir, args.visual,
                          args.experiment + args.test_date)
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(submission, exist_ok=True)
    os.makedirs(visual, exist_ok=True)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         load_file=load_file,
                         multimodel=False,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         use_TTA=args.use_TTA,
                         save_format=args.save_format,
                         snapshot=True,
                         postprocess=True)

    # Average per-case inference time, reported in minutes.
    full_test_time = (time.time() - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('{:.2f} minutes!'.format(average_time))