Code Example #1
File: train.py  Project: xuexiy1ge/AI4K
import os
from collections import OrderedDict
from glob import glob

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# project-local modules; the exact import paths are assumed from the repo layout
from model import EDSR
from data import DatasetLoader
from utils import adjust_lr, one_epoch_train_logger


def main(args):
    print("===> Loading datasets")
    file_name = sorted(os.listdir(args.data_lr))
    lr_list = []
    hr_list = []
    for one in file_name:
        lr_tmp = sorted(glob(os.path.join(args.data_lr, one, '*.png')))
        lr_list.extend(lr_tmp)
        hr_tmp = sorted(glob(os.path.join(args.data_hr, one, '*.png')))
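        # every HR clip should contain exactly 100 frames; report any that do not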
        if len(hr_tmp) != 100:
            print(one)
        hr_list.extend(hr_tmp)

    # lr_list = glob(os.path.join(args.data_lr, '*'))
    # hr_list = glob(os.path.join(args.data_hr, '*'))
    # lr_list = lr_list[0:max_index]
    # hr_list = hr_list[0:max_index]

    data_set = DatasetLoader(lr_list, hr_list, args.patch_size, args.scale)
    data_len = len(data_set)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    print("===> Building model")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = EDSR(args)
    criterion = nn.L1Loss(reduction='sum')

    print("===> Setting GPU")
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # grab the last (newest) checkpoint in the directory
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k  # add the `module.` prefix: the checkpoint was saved from model.module
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # if the checkpoint stores lr, prefer it over the command-line value
            args.lr = checkpoint.get('lr', args.lr)

    if args.start_epoch != 0:
        # an explicit start_epoch overrides the epoch restored from the checkpoint
        start_epoch = args.start_epoch

    print("===> Setting Optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-08)

    # record = []
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)

        losses, psnrs = one_epoch_train_logger(model, optimizer, criterion,
                                               data_len, train_loader, epoch,
                                               args.epochs, args.batch_size,
                                               optimizer.param_groups[0]["lr"])

        # lr = optimizer.param_groups[0]["lr"]
        # the_lr = 1e-2
        # lr_len = 2
        # while lr + (1e-9) < the_lr:
        #     the_lr *= 0.1
        #     lr_len += 1
        # record.append([losses.avg,psnrs.avg,lr_len])

        # save model
        # if epoch+1 != args.epochs:
        #     continue
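        # the filename records the epoch, mean loss, and mean PSNR for quick inspection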
        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edsr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
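        # save model.module so the weights load into a model without the DataParallel wrapper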
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
Code Example #2
# excerpt: args, device, lr, train_set, test_set, upscale_factor,
# input_channels, target_channels, n_resblocks, n_patch_features, and
# training_data_loader are defined earlier in the script
import torch.nn as nn
import torch.optim as optim

from model import EDSR  # project-local module; import path assumed

if args.verbose:
    print("{} training images / {} testing images".format(
        len(train_set), len(test_set)))
    print("===> dataset loaded !")
    print('===> Building model')

model = EDSR(upscale_factor,
             input_channels,
             target_channels,
             n_resblocks=n_resblocks,
             n_feats=n_patch_features,
             res_scale=.1,
             bn=None).to(device)

criterion = nn.MSELoss()
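# AdamW decouples weight decay from the gradient update; only lr is set explicitly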
optimizer = optim.AdamW(model.parameters(), lr=lr)

if args.verbose:
    print(model.parameters())  # note: prints a generator object, not the parameter tensors
    print(model)
    print("==> Model built")


def train(epoch):
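    # running totals: overall loss plus per-channel loss and SSIM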
    epoch_loss = 0
    epoch_loss_indiv = [0] * len(target_channels)
    epoch_ssim_indiv = [0] * len(target_channels)
    for iteration, batch in enumerate(training_data_loader, 1):
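        # each batch pairs an input with its target; move both to the training device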
        inp, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()