def main():
    # set up environment and random seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
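    # Note (assumption): the seeds above do not make cuDNN kernels fully
    # deterministic. If exact run-to-run reproducibility is required, the
    # following flags can additionally be set (at some speed cost):
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False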

    # setup networks
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = model.cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = ("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'

    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    # For each sub-epoch the loader draws 1000 patches from 50 subjects
    # (20 patches per subject).
    train_list = os.path.join(args.data_dir, args.train_list)
    train_set = Dataset(train_list,
                        root=args.data_dir,
                        for_train=True,
                        num_patches=args.num_patches,
                        transforms=args.train_transforms,
                        sample_size=args.sample_size,
                        sub_sample_size=args.sub_sample_size,
                        target_size=args.target_size)

    num_iters = args.num_iters or (len(train_set) *
                                   args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.data_dir, args.valid_list)
        valid_set = Dataset(valid_list,
                            root=args.data_dir,
                            for_train=False,
                            crop=False,
                            transforms=args.test_transforms,
                            sample_size=args.sample_size,
                            sub_sample_size=args.sub_sample_size,
                            target_size=args.target_size)
        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=4,
                                  pin_memory=True)

        train_valid_set = Dataset(train_list,
                                  root=args.data_dir,
                                  for_train=False,
                                  crop=False,
                                  transforms=args.test_transforms,
                                  sample_size=args.sample_size,
                                  sub_sample_size=args.sub_sample_size,
                                  target_size=args.target_size)
        train_valid_loader = DataLoader(train_valid_set,
                                        batch_size=1,
                                        shuffle=False,
                                        collate_fn=train_valid_set.collate,
                                        num_workers=4,
                                        pin_memory=True)

    start = time.time()

    enum_batches = len(train_set) / float(args.batch_size)  # batches per epoch
    # convert the epoch-based schedule and frequencies to iteration counts
    args.schedule = {
        int(k * enum_batches): v
        for k, v in args.schedule.items()
    }
    args.save_freq = int(enum_batches * args.save_freq)
    args.valid_freq = int(enum_batches * args.valid_freq)
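    # Worked example with hypothetical values: len(train_set) == 1000 and
    # batch_size == 8 give enum_batches == 125, so an epoch-based schedule
    # such as {60: 0.1, 80: 0.01} becomes {7500: 0.1, 10000: 0.01}, and
    # save_freq == 5 (epochs) becomes 625 (iterations).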

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    # note: `label` is not used below; the targets come packed inside `data`
    for i, (data, label) in enumerate(train_loader, args.start_iter):

        ## validation
        #if args.valid_list and  (i % args.valid_freq) == 0:
        #    logging.info('-'*50)
        #    msg  =  'Iter {}, Epoch {:.4f}, {}'.format(i, i/enum_batches, 'validation')
        #    logging.info(msg)
        #    with torch.no_grad():
        #        validate(valid_loader, model, batch_size=args.mini_batch_size, names=valid_set.names)

        # actual training
        adjust_learning_rate(optimizer, i)
        # split the batch into mini-batches and take one optimizer step per mini-batch
        for data in zip(*[d.split(args.mini_batch_size) for d in data]):

            data = [t.cuda(non_blocking=True) for t in data]
            x1, x2, target = data[:3]

            if len(data) > 3:  # has mask
                m1, m2 = data[3:]
                x1 = add_mask(x1, m1, 1)
                x2 = add_mask(x2, m2, 1)

            # compute output
            output = model((x1, x2))  # output nx5x9x9x9, target nx9x9x9
            loss = criterion(output, target, args.alpha)

            # measure accuracy and record loss
            losses.update(loss.item(), target.numel())

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i + 1) % args.save_freq == 0:
            epoch = int((i + 1) // enum_batches)

            file_name = os.path.join(ckpts, 'model_epoch_{}.tar'.format(epoch))
            torch.save(
                {
                    'iter': i + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.4f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)

        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.tar')
    torch.save(
        {
            'iter': i,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)

    if args.valid_list:
        logging.info('-' * 50)
        msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
                                                 'validate validation data')
        logging.info(msg)

        with torch.no_grad():
            validate(valid_loader,
                     model,
                     batch_size=args.mini_batch_size,
                     names=valid_set.names,
                     out_dir=args.out)

        #logging.info('-'*50)
        #msg  =  'Iter {}, Epoch {:.4f}, {}'.format(i, i/enum_batches, 'validate training data')
        #logging.info(msg)

        #with torch.no_grad():
        #    validate(train_valid_loader, model, batch_size=args.mini_batch_size, names=train_valid_set.names, verbose=False)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
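
# The examples on this page rely on several project helpers (CycleSampler,
# init_fn, AverageMeter, adjust_learning_rate, add_mask, validate, ...) that
# are defined elsewhere in the repository. The sketch below is only a
# plausible reading of the two sampling helpers used to build train_loader;
# it is an assumption, not the repository's actual code.
import numpy as np
import torch
from torch.utils.data.sampler import Sampler


class CycleSampler(Sampler):
    """Yield `num_samples` indices by cycling through fresh permutations of
    the dataset, so a single DataLoader pass can span many epochs."""

    def __init__(self, data_size, num_samples):
        self.data_size = data_size
        self.num_samples = num_samples

    def __iter__(self):
        indices = []
        while len(indices) < self.num_samples:
            indices.extend(np.random.permutation(self.data_size).tolist())
        return iter(indices[:self.num_samples])

    def __len__(self):
        return self.num_samples


def init_fn(worker_id):
    # Re-seed numpy in each DataLoader worker so augmentations differ across
    # workers while remaining reproducible for a fixed torch seed.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)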
# Example #2
def main():
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, only the CUDA version is supported"

    torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).to(device)
    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = ("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'

    msg += '\n' + str(args)
    logging.info(msg)

    Dataset = getattr(datasets, args.dataset)

    if args.prefix_path:
        args.train_data_dir = os.path.join(args.prefix_path,
                                           args.train_data_dir)
    train_list = os.path.join(args.train_data_dir, args.train_list)
    train_set = Dataset(train_list,
                        root=args.train_data_dir,
                        for_train=True,
                        transforms=args.train_transforms)

    num_iters = args.num_iters or (len(train_set) *
                                   args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.train_data_dir, args.valid_list)
        valid_set = Dataset(valid_list,
                            root=args.train_data_dir,
                            for_train=False,
                            transforms=args.test_transforms)

        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=args.workers,
                                  pin_memory=True)

    start = time.time()

    enum_batches = len(train_set) / float(
        args.batch_size)  # number of batches per epoch
    # convert the epoch-based schedule keys to iteration counts
    args.schedule = {
        int(k * enum_batches): v
        for k, v in args.schedule.items()
    }
    # args.save_freq = int(enum_batches * args.save_freq)
    # args.valid_freq = int(enum_batches * args.valid_freq)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, data in enumerate(train_loader, args.start_iter):

        elapsed_bsize = int(i / enum_batches) + 1
        epoch = int((i + 1) / enum_batches)
        setproctitle.setproctitle("Epoch:{}/{}".format(elapsed_bsize,
                                                       args.num_epochs))

        adjust_learning_rate(optimizer, epoch, args.num_epochs,
                             args.opt_params.lr)

        # data = [t.cuda(non_blocking=True) for t in data]
        data = [t.to(device) for t in data]
        x, target = data[:2]

        output = model(x)
        if not args.weight_type:  # for backward compatibility with older configs
            args.weight_type = 'square'

        # loss = criterion(output, target, args.eps,args.weight_type)
        # loss = criterion(output, target,args.alpha,args.gamma) # for focal loss
        loss = criterion(output, target, *args.kwargs)

        # measure accuracy and record loss
        losses.update(loss.item(), target.numel())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % int(enum_batches * args.save_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            file_name = os.path.join(ckpts, 'model_epoch_{}.pth'.format(epoch))
            torch.save(
                {
                    'iter': i + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        # validation
        if (i + 1) % int(enum_batches * args.valid_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            logging.info('-' * 50)
            msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
                                                     'validation')
            logging.info(msg)
            with torch.no_grad():
                validate_softmax(valid_loader,
                                 model,
                                 cfg=args.cfg,
                                 savepath='',
                                 names=valid_set.names,
                                 scoring=True,
                                 verbose=False,
                                 use_TTA=False,
                                 snapshot=False,
                                 postprocess=False,
                                 cpu_only=False)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.7f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)

        logging.info(msg)
        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.pth')
    torch.save(
        {
            'iter': i,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
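
# `adjust_learning_rate(optimizer, epoch, num_epochs, init_lr)` as called in
# Examples 2 and 3 is another project helper not shown here. A minimal sketch,
# assuming a polynomial ("poly") decay policy; the repository's actual schedule
# may differ. (Example 1 uses a different variant driven by args.schedule.)
def adjust_learning_rate(optimizer, epoch, num_epochs, init_lr, power=0.9):
    # Decay the learning rate polynomially from init_lr towards zero.
    lr = init_lr * (1 - epoch / num_epochs) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr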
# Example #3
def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, only the CUDA version is supported"
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = ("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'

    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    train_list = os.path.join(args.train_data_dir, args.train_list)
    train_set = Dataset(train_list,
                        root=args.train_data_dir,
                        for_train=True,
                        transforms=args.train_transforms)

    num_iters = args.num_iters or (len(train_set) *
                                   args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    start = time.time()

    enum_batches = len(train_set) / float(
        args.batch_size)  # nums_batch per epoch

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, data in enumerate(train_loader, args.start_iter):

        elapsed_bsize = int(i / enum_batches) + 1
        epoch = int((i + 1) / enum_batches)
        setproctitle.setproctitle("Epoch:{}/{}".format(elapsed_bsize,
                                                       args.num_epochs))

        # actual training
        adjust_learning_rate(optimizer, epoch, args.num_epochs,
                             args.opt_params.lr)

        data = [t.cuda(non_blocking=True) for t in data]
        x, target = data[:2]

        output = model(x)

        if not args.weight_type:  # for backward compatibility with older configs
            args.weight_type = 'square'

        if args.criterion_kwargs is not None:
            loss = criterion(output, target, **args.criterion_kwargs)
        else:
            loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), target.numel())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % int(enum_batches * args.save_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:

            file_name = os.path.join(ckpts, 'model_epoch_{}.pth'.format(epoch))
            torch.save(
                {
                    'iter': i,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.7f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)

        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.pth')
    torch.save(
        {
            'iter': i,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
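
# `AverageMeter` is used by all three examples but is not defined on this
# page. A minimal sketch matching how it is used above (update(value, n),
# .avg, .reset()); the project's own implementation may track more statistics.
class AverageMeter(object):
    """Keep a running, count-weighted average of a scalar value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count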