Example 1
train_set = get_training_set(opt.dataset)
val_set = get_val_set(opt.dataset)
training_data_loader = DataLoader(dataset=train_set,
                                  batch_size=opt.batchSize,
                                  shuffle=True)
val_data_loader = DataLoader(dataset=val_set,
                             batch_size=opt.valBatchSize,
                             shuffle=False)

print('===> Loading pre-trained model and building model')
model_r = LapSRN_r().to(device)
model_g = LapSRN_g().to(device)
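# note: the next line rebinds the name Loss from the loss class to an
# instance, so the class cannot be instantiated again under that name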
Loss = Loss()
criterion = nn.MSELoss()
if cuda:
    Loss = Loss.cuda()
    criterion = criterion.cuda()


def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        LR_r, LR_g, HR_2_target, HR_4_target = batch[0].to(device), batch[
            1].to(device), batch[2].to(device), batch[3].to(device)

        optimizer_r.zero_grad()
        optimizer_g.zero_grad()

        HR_2_r, HR_4_r = model_r(LR_r)
        HR_2_g, HR_4_g = model_g(LR_g)
Example 2
class Solver(object):
    def __init__(self, train_loader, val_loader, test_loader, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1,
                                                                  1) / 255
        self.beta = 0.3
        self.device = torch.device('cpu')
        if self.config.cuda:
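            # let cuDNN benchmark and cache the fastest conv algorithms
            # (helps when input sizes are fixed)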
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        if config.visdom:
            self.visual = Viz_visdom("NLDF", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        self.net = build_model()
        if self.config.mode == 'train':
            self.loss = Loss(self.config.area, self.config.boundary)
        self.net = self.net.to(self.device)
        if self.config.cuda and self.config.mode == 'train':
            self.loss = self.loss.cuda()
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        else:
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'NLDF')

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def clip(self, y):
        return torch.clamp(y, 0.0, 1.0)

    def eval_mae(self, y_pred, y):
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        prec, recall = torch.zeros(num), torch.zeros(num)
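        # the top threshold sits just below 1.0 so pixels predicted at
        # exactly 1.0 still count as positive at every threshold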
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    def validation(self):
        avg_mae = 0.0
        self.net.eval()
        for i, data_batch in enumerate(self.val_loader):
            with torch.no_grad():
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
            avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num):
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        for i, data_batch in enumerate(self.test_loader):
            with torch.no_grad():
                images, labels = data_batch
                shape = labels.size()[2:]
                images = images.to(self.device)
                prob_pred = F.interpolate(self.net(images),
                                          size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu()
            mae = self.eval_mae(prob_pred, labels)
            prec, recall = self.eval_pr(prob_pred, labels, num)
            print("[%d] mae: %.4f" % (i, mae))
            print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
            avg_mae += mae
            avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
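        # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), per threshold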
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # zero out NaNs (0/0 where prec + recall == 0)
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    def train(self):
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
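        # track the best validation MAE seen so far (lower is better)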
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
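                # drop the trailing partial batch so each epoch runs exactly iter_num steps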
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.cpu().item()
                print(
                    'epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                    (epoch, self.config.epoch, i, iter_num, loss.cpu().item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.cpu().item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)
                                           ])
                    self.visual.plot_current_errors(epoch, i / iter_num,
                                                    avg_err, 1)
                    img = OrderedDict([('origin', self.mean + x.cpu()[0]),
                                       ('label', y.cpu()[0][0]),
                                       ('pred_label', y_pred.cpu()[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      (best_mae, mae))
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' %
                      (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)
Example 3
class Solver(object):
    def __init__(self, train_loader, val_loader, test_loader, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.beta = 0.3  # for max F_beta metric
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1,
                                                                  1) / 255
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    # print the network information and parameter numbers
    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        self.net = build_model()
        if self.config.mode == 'train': self.loss = Loss()
        if self.config.cuda: self.net = self.net.cuda()
        if self.config.cuda and self.config.mode == 'train':
            self.loss = self.loss.cuda()
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        else:
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precision and recall at num thresholds evenly spaced over [0, 1)
    def eval_pr(self, y_pred, y, num):
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    # validation: uses the resized images and evaluates only the MAE metric
    def validation(self):
        avg_mae = 0.0
        self.net.eval()
        for i, data_batch in enumerate(self.val_loader):
            with torch.no_grad():  # replaces the removed volatile=True flag
                images, labels = data_batch
                if self.config.cuda:
                    images, labels = images.cuda(), labels.cuda()
                prob_pred = self.net(images)
                # fuse the selected side-output maps by averaging them
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[i] for i in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
            avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: uses the original image size; evaluates MAE and max F_beta
    def test(self, num):
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        for i, data_batch in enumerate(self.test_loader):
            images, labels = data_batch
            shape = labels.size()[2:]
            with torch.no_grad():  # replaces the removed volatile=True flag
                if self.config.cuda:
                    images = images.cuda()
                prob_pred = self.net(images)
                prob_pred = torch.mean(torch.cat(
                    [prob_pred[i] for i in self.select], dim=1),
                                       dim=1,
                                       keepdim=True)
                # F.upsample is deprecated; F.interpolate replaces it
                prob_pred = F.interpolate(prob_pred,
                                          size=shape,
                                          mode='bilinear',
                                          align_corners=True).cpu()
            mae = self.eval_mae(prob_pred, labels)
            prec, recall = self.eval_pr(prob_pred, labels, num)
            print("[%d] mae: %.4f" % (i, mae))
            print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
            avg_mae += mae
            avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta**2) * avg_prec * avg_recall / (
            self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # zero out NaNs (0/0 where prec + recall == 0)
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    # training phase
    def train(self):
        # preallocated input/label buffers, refilled in place every iteration
        x = torch.FloatTensor(self.config.batch_size, self.config.n_color,
                              self.config.img_size, self.config.img_size)
        y = torch.FloatTensor(self.config.batch_size, self.config.n_color,
                              self.config.img_size, self.config.img_size)
        if self.config.cuda:
            cudnn.benchmark = True
            x, y = x.cuda(), y.cuda()
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                images, labels = data_batch
                if self.config.cuda:
                    images, labels = images.cuda(), labels.cuda()
                x.resize_as_(images).copy_(images)
                y.resize_as_(labels).copy_(labels)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(),
                                      self.config.clip_gradient)
                # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num,
                       loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)
                                           ])
                    self.visual.plot_current_errors(epoch, i / iter_num,
                                                    avg_err, 1)
                    y_show = torch.mean(torch.cat(
                        [y_pred[i] for i in self.select], dim=1),
                                        dim=1,
                                        keepdim=True)
                    img = OrderedDict([('origin', self.mean + images.cpu()[0]),
                                       ('label', labels.cpu()[0][0]),
                                       ('pred_label', y_show.detach().cpu()[0][0])
                                       ])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' %
                      (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(),
                               '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(
                    self.net.state_dict(), '%s/models/epoch_%d.pth' %
                    (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(),
                   '%s/models/final.pth' % self.config.save_fold)
Example 4
device = torch.device("cuda" if opt.cuda else "cpu")

print('===> Loading datasets')
train_set = get_training_set(opt.dataset)
val_set = get_val_set(opt.dataset)
training_data_loader = DataLoader(dataset=train_set, batch_size=opt.batchSize, shuffle=True)
val_data_loader = DataLoader(dataset=val_set, batch_size=opt.valBatchSize, shuffle=False)


print('===> Loading pre-trained model and building model')

Loss_r = Loss()
Loss_g = Loss()
criterion = nn.MSELoss()
if cuda:
    Loss_r = Loss_r.cuda()
    Loss_g = Loss_g.cuda()
    criterion = criterion.cuda()


def train(epoch):
    epoch_loss_r, epoch_loss_g = 0, 0
    for _, batch in enumerate(training_data_loader, 1):
        LR_r, LR_g, HR_2_r_Target, HR_2_g_Target, HR_4_r_Target, HR_4_g_Target = \
            batch[0].to(device), batch[1].to(device), batch[2].to(device), \
            batch[3].to(device), batch[4].to(device), batch[5].to(device)

        optimizer_r.zero_grad()
        HR_2_r, HR_4_r = model_r(LR_r)
        loss_r_X2 = Loss_r(HR_2_r, HR_2_r_Target)
        loss_r_X4 = Loss_r(HR_4_r, HR_4_r_Target)
Example 5
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--start-epoch',
                        type=int,
                        default=1,
                        metavar='N',
                        help='starting epoch (default: 1)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--num-workers',
                        type=int,
                        default=64,
                        metavar='N',
                        help='number of worker processes for the DataLoader')

    parser.add_argument('--save-freq',
                        default=5,
                        type=int,
                        metavar='S',
                        help='save frequency')

    parser.add_argument('--save-dir',
                        default='',
                        type=str,
                        metavar='SAVE',
                        help='directory to save checkpoint (default: none)')

    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    parser.add_argument('--data-dir',
                        type=str,
                        default='',
                        metavar='N',
                        help='directory containing the raw data files')

    parser.add_argument('--split-ratio',
                        type=float,
                        default=0.1,
                        metavar='N',
                        help='split ratio of validation set')

    parser.add_argument(
        '--train-full',
        type=int,
        default=0,
        metavar='N',
        help='whether to use the entire training set for training')

    args = parser.parse_args()
    # pre-process data
    data_dir = args.data_dir
    split_ratio = args.split_ratio
    preprocess_data(data_dir, split_ratio)
    save_dir = args.save_dir if args.save_dir else data_dir

    # set up GPU; controlling visibility via CUDA_VISIBLE_DEVICES is often preferable
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    train_loader, test_loader = prepare_dataloader(data_dir,
                                                   args,
                                                   train_full=args.train_full)
    model = Net()
    loss_form = Loss()
    if use_cuda:
        model.cuda()
        loss_form.cuda()

    start_epoch = 1
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch'] + 1
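    # an explicit --start-epoch overrides the epoch restored from the checkpoint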
    if args.start_epoch > 1:
        start_epoch = args.start_epoch
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    # save logs to file for retrospective checking
    logfile = os.path.join(save_dir, 'log')
    sys.stdout = Logger(logfile)

    # snapshot the .py files so the exact code for this run can be revisited later
    pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
    if not os.path.exists(os.path.join(save_dir, 'code')):
        os.mkdir(os.path.join(save_dir, 'code'))
    for f in pyfiles:
        shutil.copy(f, os.path.join(save_dir, 'code', f))

    # training and testing (validation)
    for epoch in range(start_epoch, args.epochs + 1):
        train(args, model, train_loader, loss_form, use_cuda, optimizer, epoch)
        test(args, model, test_loader, loss_form, use_cuda)
        if epoch % args.save_freq == 0:
            state_dict = model.state_dict()
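            # move every weight to CPU so the checkpoint loads on machines without a GPU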
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cpu()

            torch.save(
                {
                    'epoch': epoch,
                    'save_dir': save_dir,
                    'state_dict': state_dict,
                    'args': args
                }, os.path.join(save_dir, '%03d.ckpt' % epoch))
Example 6
print('---------- Networks architecture -------------')
print_network(model)
print('----------------------------------------------')

if opt.pretrained:
    model_name = os.path.join(opt.save_folder, opt.pretrained_sr)
    if os.path.exists(model_name):
        # map_location loads the tensors onto the CPU, so a GPU-trained
        # checkpoint can be restored on any machine
        model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')

print('Pretrained ColorNet is loaded')
if cuda:
    model = model.cuda(gpus_list[0])
    criterion = criterion.cuda(gpus_list[0])

if __name__ == "__main__":
    for _ in range(2):
        optimizer = optim.Adam(model.parameters(),
                               lr=opt.lr,
                               betas=(0.9, 0.999),
                               eps=1e-8)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=opt.lr_step, gamma=0.5)
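        # MultiStepLR halves the learning rate (gamma=0.5) at each milestone in opt.lr_step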
        for epoch in range(opt.start_epoch, opt.nEpochs + 1):
            avg_loss = train(epoch)
            scheduler.step()

            if (epoch + 1) % opt.snapshots == 0:
                checkpoint(epoch)
Example 7
    # instantiate network
    print("===> Building model")
    device_ids = list(range(opt.n_gpus))
    net = Net()

    # if running on GPU and we want to use cuda move model there
    print("===> Setting GPU")
    cuda = opt.gpu_mode
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
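    # DataParallel replicates the model and splits each batch across the listed GPUs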
    net = nn.DataParallel(net, device_ids=device_ids)
    net = net.cuda()

    # create loss
    criterion_L1_cb = Loss(eps=1e-3)
    criterion_L1_cb = criterion_L1_cb.cuda()

    print('---------- Networks architecture -------------')
    print_network(net)
    print('----------------------------------------------')

    # optionally resume from a checkpoint
    if opt.resume:
        if opt.resume_dir is not None:
            if isinstance(net, torch.nn.DataParallel):
                net.module.load_state_dict(torch.load(opt.resume_dir))
            else:
                net.load_state_dict(torch.load(opt.resume_dir))
            print('Network loaded from {}'.format(opt.resume_dir))

    # create optimizer