Example #1
def main():
    log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log',
                           config.experiment + config.data)
    log_file = log_dir + '.txt'
    log_config(log_file)
    logging.info(
        '-------------------------------------------These are all the configurations-----------------------------------------'
    )
    logging.info(config)
    logging.info(
        '---------------------------------------------This is a dividing line-------------------------------------------'
    )
    logging.info('{}'.format(config.description))

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    #model = generate_model(config)
    model = getattr(models, config.model_name)()
    #model = getattr(models, config.model_name)(c=4,n=32,channels=128, groups=16,norm='sync_bn', num_classes=4,output_func='softmax')
    model = torch.nn.DataParallel(model).cuda()
    model.train()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay,
                                 amsgrad=config.amsgrad)
    # criterion = getattr(criterions, config.criterion)
    criterion = torch.nn.CrossEntropyLoss()

    checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                  'checkpoint',
                                  config.experiment + config.data)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    resume = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                          config.resume)
    if os.path.isfile(resume) and config.load:
        logging.info('loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume)
        config.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optim_dict'])
        logging.info(
            'Successfully loaded checkpoint {}; resuming training from epoch {}'.
            format(config.resume, config.start_epoch))

    else:
        logging.info('re-training!!!')

    train_list = os.path.join(config.root, config.train_dir, config.train_file)
    train_root = os.path.join(config.root, config.train_dir)

    train_set = BraTS(train_list, train_root, config.mode)
    logging.info('Samples for train = {}'.format(len(train_set)))

    num_iters = (len(train_set) * config.end_epoch) // config.batch_size
    num_iters -= (len(train_set) * config.start_epoch) // config.batch_size
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              pin_memory=True)

    start_time = time.time()

    torch.set_grad_enabled(True)

    for epoch in range(config.start_epoch, config.end_epoch):
        loss_epoch = []
        area1 = []
        area2 = []
        area4 = []

        setproctitle.setproctitle('{}:{} {}/{}'.format(config.user,
                                                       config.model_name,
                                                       epoch + 1,
                                                       config.end_epoch))
        start_epoch = time.time()

        for i, data in enumerate(train_loader):

            adjust_learning_rate(optimizer, epoch, config.end_epoch, config.lr)
            #warm_up_learning_rate_adjust2(config.lr, epoch, config.warm, config.end_epoch, optimizer)
            data = [t.cuda(non_blocking=True) for t in data]
            x, target = data
            output = model(x)

            target[target == 4] = 3  # remap BraTS label 4 to 3 so class indices are contiguous (0-3)

            loss = criterion(output, target)
            logging.info('Epoch: {}_Iter:{}  loss: {:.5f} ||'.format(
                epoch, i, loss))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_epoch = time.time()

        # Save every `save_freq` epochs and at the last two epochs before the final one.
        if (epoch + 1) % int(config.save_freq) == 0 \
                or (epoch + 1) == (config.end_epoch - 1) \
                or (epoch + 1) == (config.end_epoch - 2):
            file_name = os.path.join(checkpoint_dir,
                                     'model_epoch_{}.pth'.format(epoch))
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        epoch_time_minute = (end_epoch - start_epoch) / 60
        remaining_time_hour = (config.end_epoch - epoch -
                               1) * epoch_time_minute / 60
        logging.info('Current epoch time consumption: {:.2f} minutes!'.format(
            epoch_time_minute))
        logging.info('Estimated remaining training time: {:.2f} hours!'.format(
            remaining_time_hour))

    final_name = os.path.join(checkpoint_dir, 'model_epoch_last.pth')
    torch.save(
        {
            'epoch': config.end_epoch,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, final_name)
    end_time = time.time()
    total_time = (end_time - start_time) / 3600
    logging.info('The total training time is {:.2f} hours'.format(total_time))

    logging.info(
        '-----------------------------------The training process finished!------------------------------------'
    )
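
Both training examples call an adjust_learning_rate helper that is not defined in the snippets. Below is a minimal sketch of such a helper, assuming the polynomial ("poly") decay schedule that is common in BraTS training scripts; the function body and the power=0.9 exponent are assumptions, not taken from the code above:

def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9):
    # Assumed poly schedule: lr = init_lr * (1 - epoch / max_epoch) ** power
    lr = round(init_lr * (1 - epoch / max_epoch) ** power, 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr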
Example #2
def main_worker():
    if args.local_rank == 0:
        log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log', args.experiment+args.date)
        log_file = log_dir + '.txt'
        log_args(log_file)
        logging.info('--------------------------------------These are all the arguments----------------------------------')
        for arg in vars(args):
            logging.info('{}={}'.format(arg, getattr(args, arg)))
        logging.info('----------------------------------------This is a dividing line----------------------------------')
        logging.info('{}'.format(args.description))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.distributed.init_process_group('nccl')
    torch.cuda.set_device(args.local_rank)

    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")

    model.cuda(args.local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
                                                find_unused_parameters=True)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, amsgrad=args.amsgrad)


    criterion = getattr(criterions, args.criterion)

    if args.local_rank == 0:
        checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint', args.experiment+args.date)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    resume = ''  # set to a checkpoint path to resume training

    writer = SummaryWriter()

    if os.path.isfile(resume) and args.load:
        logging.info('loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)

        model.load_state_dict(checkpoint['state_dict'])

        logging.info('Successfully loaded checkpoint {}; resuming training from epoch {}'
                     .format(args.resume, args.start_epoch))
    else:
        logging.info('re-training!!!')

    train_list = os.path.join(args.root, args.train_dir, args.train_file)
    train_root = os.path.join(args.root, args.train_dir)

    train_set = BraTS(train_list, train_root, args.mode)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    logging.info('Samples for train = {}'.format(len(train_set)))


    num_gpu = (len(args.gpu) + 1) // 2  # assumes args.gpu is a comma-separated device list such as '0,1,2,3'

    train_loader = DataLoader(dataset=train_set, sampler=train_sampler, batch_size=args.batch_size // num_gpu,
                              drop_last=True, num_workers=args.num_workers, pin_memory=True)

    start_time = time.time()

    torch.set_grad_enabled(True)

    for epoch in range(args.start_epoch, args.end_epoch):
        train_sampler.set_epoch(epoch)  # shuffle
        setproctitle.setproctitle('{}: {}/{}'.format(args.user, epoch+1, args.end_epoch))
        start_epoch = time.time()

        for i, data in enumerate(train_loader):

            adjust_learning_rate(optimizer, epoch, args.end_epoch, args.lr)

            x, target = data
            x = x.cuda(args.local_rank, non_blocking=True)
            target = target.cuda(args.local_rank, non_blocking=True)


            output = model(x)

            loss, loss1, loss2, loss3 = criterion(output, target)
            reduce_loss = all_reduce_tensor(loss, world_size=num_gpu).data.cpu().numpy()
            reduce_loss1 = all_reduce_tensor(loss1, world_size=num_gpu).data.cpu().numpy()
            reduce_loss2 = all_reduce_tensor(loss2, world_size=num_gpu).data.cpu().numpy()
            reduce_loss3 = all_reduce_tensor(loss3, world_size=num_gpu).data.cpu().numpy()

            if args.local_rank == 0:
                logging.info('Epoch: {}_Iter:{}  loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
                             .format(epoch, i, reduce_loss, reduce_loss1, reduce_loss2, reduce_loss3))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_epoch = time.time()
        if args.local_rank == 0:
            # Save every `save_freq` epochs and at the last three epochs before the final one.
            if (epoch + 1) % int(args.save_freq) == 0 \
                    or (epoch + 1) == (args.end_epoch - 1) \
                    or (epoch + 1) == (args.end_epoch - 2) \
                    or (epoch + 1) == (args.end_epoch - 3):
                file_name = os.path.join(checkpoint_dir, 'model_epoch_{}.pth'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                },
                    file_name)

            writer.add_scalar('lr:', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar('loss:', reduce_loss, epoch)
            writer.add_scalar('loss1:', reduce_loss1, epoch)
            writer.add_scalar('loss2:', reduce_loss2, epoch)
            writer.add_scalar('loss3:', reduce_loss3, epoch)

        if args.local_rank == 0:
            epoch_time_minute = (end_epoch-start_epoch)/60
            remaining_time_hour = (args.end_epoch-epoch-1)*epoch_time_minute/60
            logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute))
            logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour))

    if args.local_rank == 0:
        writer.close()

        final_name = os.path.join(checkpoint_dir, 'model_epoch_last.pth')
        torch.save({
            'epoch': args.end_epoch,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        },
            final_name)
    end_time = time.time()
    total_time = (end_time-start_time)/3600
    logging.info('The total training time is {:.2f} hours'.format(total_time))

    logging.info('----------------------------------The training process finished!-----------------------------------')
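
Example #2 averages each step's losses across processes with an all_reduce_tensor helper before logging. Below is a minimal sketch of what such a helper typically looks like under torch.distributed, assuming the intent is the mean loss over all ranks; the clone and the division by world_size are assumptions, not taken from the snippet:

import torch.distributed as dist

def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1):
    # Sum the tensor across all ranks, then divide to get the mean (assumed behaviour).
    tensor = tensor.clone()
    dist.all_reduce(tensor, op)
    tensor.div_(world_size)
    return tensor

Because the worker reads args.local_rank and calls init_process_group('nccl'), it is meant to be started through PyTorch's distributed launcher, for example (the script name is a placeholder):

python -m torch.distributed.launch --nproc_per_node=4 train.py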
Example #3
def main():
    setattr(config, 'mode', 'train_fold')
    setattr(config, 'valid_file', 'valid_small.txt')
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    model = getattr(models, config.model_name)(c=4,
                                               n=32,
                                               channels=128,
                                               groups=16,
                                               norm='sync_bn',
                                               num_classes=4,
                                               output_func='softmax')
    model = torch.nn.DataParallel(model).cuda()

    load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                             'checkpoint',
                             config.experiment + config.test_date,
                             config.test_file)

    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        config.start_epoch = checkpoint['epoch']
        print('Successfully loaded checkpoint {}'.format(
            os.path.join(config.experiment + config.test_date,
                         config.test_file)))
    else:
        print('There is no resume file to load!')

    valid_list = os.path.join(config.root, config.train_dir, 'valid_small.txt')
    valid_root = os.path.join(config.root, config.train_dir)
    valid_set = BraTS(valid_list, valid_root, mode='train_fold')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=config.num_workers,
                              pin_memory=True)

    submission = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                              config.output_dir, config.submission,
                              config.experiment + config.test_date)
    visual = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                          config.output_dir, 'visual_fold',
                          config.experiment + config.test_date)
    if not os.path.exists(submission):
        os.makedirs(submission)
    if not os.path.exists(visual):
        os.makedirs(visual)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         scoring=False,
                         use_TTA=False,
                         save_format=config.save_format,
                         postprocess=True,
                         snapshot=True)

    end_time = time.time()
    full_test_time = (end_time - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('Average time per case: {:.2f} minutes!'.format(average_time))
Example #4
def main():

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")

    model = torch.nn.DataParallel(model).cuda()

    load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                             'checkpoint', args.experiment + args.test_date,
                             args.test_file)

    if os.path.exists(load_file):
        checkpoint = torch.load(load_file)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        print('Successfully loaded checkpoint {}'.format(
            os.path.join(args.experiment + args.test_date, args.test_file)))
    else:
        print('There is no resume file to load!')

    valid_list = os.path.join(args.root, args.valid_dir, args.valid_file)
    valid_root = os.path.join(args.root, args.valid_dir)
    valid_set = BraTS(valid_list, valid_root, mode='test')
    print('Samples for valid = {}'.format(len(valid_set)))

    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    submission = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                              args.output_dir, args.submission,
                              args.experiment + args.test_date)
    visual = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                          args.output_dir, args.visual,
                          args.experiment + args.test_date)

    if not os.path.exists(submission):
        os.makedirs(submission)
    if not os.path.exists(visual):
        os.makedirs(visual)

    start_time = time.time()

    with torch.no_grad():
        validate_softmax(valid_loader=valid_loader,
                         model=model,
                         load_file=load_file,
                         multimodel=False,
                         savepath=submission,
                         visual=visual,
                         names=valid_set.names,
                         use_TTA=args.use_TTA,
                         save_format=args.save_format,
                         snapshot=True,
                         postprocess=True)

    end_time = time.time()
    full_test_time = (end_time - start_time) / 60
    average_time = full_test_time / len(valid_set)
    print('Average time per case: {:.2f} minutes!'.format(average_time))
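
All four examples wrap the model in DataParallel or DistributedDataParallel before saving, so the keys in the stored state_dict carry a 'module.' prefix. Below is a minimal sketch of loading such a checkpoint into an unwrapped network for standalone inference; build_model() is a hypothetical placeholder for the bare model constructor, while the file name comes from the examples above:

import torch

# Placeholder for constructing the bare network (e.g. the TransBTS constructor
# used above) without the DataParallel wrapper.
model = build_model()

checkpoint = torch.load('model_epoch_last.pth', map_location='cpu')
# Strip the 'module.' prefix added by (Distributed)DataParallel.
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)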