class CustomSimCLR50(nn.Module):
    """SimCLR-style encoder built on a ResNet-50 backbone loaded from a checkpoint.

    The backbone's last child module is dropped and a 1x1 convolution is
    appended that projects the backbone's 2048-channel output to
    ``latent_dim`` channels. ``SimCLR_head`` then maps those features to the
    ``head_dim``-sized projection used for the contrastive loss.

    Args:
        batch_size: Batch size used only for the construction-time shape check.
        latent_dim: Number of channels of the hidden representation.
        head_dim: Output size passed to ``SimCLR_head``.
        weights_path: Checkpoint whose ``'state_dict'`` entry is loaded into
            the backbone. Defaults to ``'./resnet50-1x.pth'``, the path that
            was previously hard-coded.
    """

    def __init__(self, batch_size, latent_dim, head_dim,
                 weights_path='./resnet50-1x.pth'):
        super().__init__()
        self.pretrained_model = GetModel('Resnet50')
        self.load_weights(weights_path)
        # Keep every child of the backbone except the last one (presumably the
        # fc classifier of a torchvision-style ResNet -- verify against GetModel).
        self.features = torch.nn.Sequential(
            *list(self.pretrained_model.children())[:-1])
        # The module name (including its space) is part of the state_dict keys;
        # keep it unchanged so existing checkpoints still load.
        self.features.add_module('Final Convolution',
                                 nn.Conv2d(2048, latent_dim, 1))
        self.head = SimCLR_head(latent_dim, head_dim)
        self.test_shape(batch_size, latent_dim)

    def test_shape(self, batch_size, latent_dim):
        """Feed random dummy batches through ``features`` and ``head``.

        ``features`` receives a ``(batch_size, 3, 128, 128)`` tensor and
        ``head`` a ``(batch_size, latent_dim, 1, 1)`` tensor; both are handed to
        ``test_dimensionality`` (presumably a shape check/printout -- see that
        helper).
        """
        test_dimensionality(self.features,
                            torch.rand(batch_size, 3, 128, 128), 'feat')
        test_dimensionality(self.head,
                            torch.rand(batch_size, latent_dim, 1, 1), 'head')

    def load_weights(self, path):
        """Load ``checkpoint['state_dict']`` from *path* into the backbone.

        Tensors are mapped to the CPU. ``torch.load`` unpickles the file, so
        only load checkpoints from trusted sources.
        """
        sd = torch.load(path, map_location='cpu')
        self.pretrained_model.load_state_dict(sd['state_dict'])

    def forward(self, x):
        """Return ``(features, projection)`` for the input batch *x*.

        Unlike ``CustomSimCLR18.forward``, the feature map is returned
        unflattened (presumably ``[bs, latent_dim, 1, 1]`` after the backbone's
        pooling -- not verified here). ``self.head`` receives the same
        unflattened tensor.
        """
        feats = self.features(x)
        return feats, self.head(feats)
 def __init__(self, batch_size, latent_dim, head_dim, pretrained):
     """Constructor body for CustomSimCLR18 (ResNet-18 backbone + 1x1 conv + SimCLR head).

     NOTE(review): this is an orphaned copy of ``CustomSimCLR18.__init__``
     with no enclosing ``class`` statement. Its 1-space indent matches no
     enclosing indentation level after the preceding method body, so Python
     rejects the module with an IndentationError. It looks like a page-scrape
     artifact; consider removing it.
     """
     super(CustomSimCLR18, self).__init__()
     self.model = GetModel('Resnet18', pretrained=pretrained)
     # Drop the backbone's last child (presumably its fc classifier -- verify).
     self.features = torch.nn.Sequential(*list(self.model.children())[:-1])
     # 512 input channels -- presumably the ResNet-18 feature width; TODO confirm.
     self.features.add_module('Final Convolution',
                              nn.Conv2d(512, latent_dim, 1))
     self.head = SimCLR_head(latent_dim, head_dim)
     self.test_shape(batch_size, latent_dim)
 def __init__(self, batch_size, latent_dim, head_dim):
     """Constructor body for CustomSimCLR50 (ResNet-50 backbone loaded from ./resnet50-1x.pth).

     NOTE(review): this is an orphaned copy of ``CustomSimCLR50.__init__``
     with no enclosing ``class`` statement. Its 1-space indent matches no
     enclosing indentation level, so Python rejects the module with an
     IndentationError. It looks like a page-scrape artifact; consider
     removing it.
     """
     super(CustomSimCLR50, self).__init__()
     self.pretrained_model = GetModel('Resnet50')
     # Hard-coded checkpoint path, resolved relative to the working directory.
     self.load_weights('./resnet50-1x.pth')
     self.features = torch.nn.Sequential(
         *list(self.pretrained_model.children())[:-1])
     # 2048 input channels -- presumably the ResNet-50 feature width; TODO confirm.
     self.features.add_module('Final Convolution',
                              nn.Conv2d(2048, latent_dim, 1))
     self.head = SimCLR_head(latent_dim, head_dim)
     self.test_shape(batch_size, latent_dim)
class CustomSimCLR18(nn.Module):
    """SimCLR-style encoder built on a ResNet-18 backbone.

    The backbone's last child module is dropped and a 1x1 convolution is
    appended that projects the backbone's 512-channel output to
    ``latent_dim`` channels. ``SimCLR_head`` maps those features to the
    ``head_dim``-sized projection used for the contrastive loss.

    Args:
        batch_size: Batch size used only for the construction-time shape check.
        latent_dim: Number of channels of the hidden representation.
        head_dim: Output size passed to ``SimCLR_head``.
        pretrained: Forwarded to ``GetModel`` as its ``pretrained`` flag.
    """

    def __init__(self, batch_size, latent_dim, head_dim, pretrained):
        super().__init__()
        self.model = GetModel('Resnet18', pretrained=pretrained)
        # Keep every child of the backbone except the last one (presumably the
        # fc classifier of a torchvision-style ResNet -- verify against GetModel).
        self.features = torch.nn.Sequential(*list(self.model.children())[:-1])
        # The module name (including its space) is part of the state_dict keys;
        # keep it unchanged so existing checkpoints still load.
        self.features.add_module('Final Convolution',
                                 nn.Conv2d(512, latent_dim, 1))
        self.head = SimCLR_head(latent_dim, head_dim)
        self.test_shape(batch_size, latent_dim)

    def test_shape(self, batch_size, latent_dim):
        """Feed random dummy batches through ``features`` and ``head``.

        ``features`` receives a ``(batch_size, 3, 128, 128)`` tensor and
        ``head`` a ``(batch_size, latent_dim, 1, 1)`` tensor; both are handed to
        ``test_dimensionality`` (presumably a shape check/printout -- see that
        helper).
        """
        test_dimensionality(self.features,
                            torch.rand(batch_size, 3, 128, 128), 'feat')
        test_dimensionality(self.head,
                            torch.rand(batch_size, latent_dim, 1, 1), 'head')

    def forward(self, x):
        """Encode *x* and project it for the contrastive loss.

        Returns:
            A tuple ``(hidden, projection)``:

            * ``hidden`` -- the feature map flattened from dim 1 onward, i.e.
              ``[bs, LATENT_DIM]`` when the features are
              ``[bs, LATENT_DIM, 1, 1]``. It is used as the input to the RL
              model.
            * ``projection`` -- ``self.head`` applied to the *unflattened*
              feature map (``[bs, HEAD_DIM]``). It is used to compute the
              contrastive loss.
        """
        feats = self.features(x)
        return torch.flatten(feats, 1), self.head(feats)
예제 #5
0
파일: run.py 프로젝트: BoveyZheng/Deblur
    torch.save(checkpoint, opt.out + '/final.pth')


if __name__ == '__main__':

    try:
        os.makedirs(opt.out)
    except IOError:
        pass

    opt.fid = open(opt.out + '/log.txt', 'w')
    print(opt)
    print(opt, '\n', file=opt.fid)

    dataloader, validloader = GetDataloaders(opt)
    net = GetModel(opt)

    if opt.log:
        opt.writer = SummaryWriter(
            comment='_%s_%s' %
            (opt.out.replace('\\', '/').split('/')[-1], opt.model))
        opt.train_stats = open(
            opt.out.replace('\\', '/') + '/train_stats.csv', 'w')
        opt.test_stats = open(
            opt.out.replace('\\', '/') + '/test_stats.csv', 'w')
        print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
        print('iter,time,memory,psnr,ssim', file=opt.test_stats)

    import time
    t0 = time.perf_counter()
    if not opt.test:
예제 #6
0
def main(opt):
    """Train or evaluate a model as configured by *opt*, logging to ``opt.out``.

    Sets ``opt.device`` and ``opt.fid`` (the ``log.txt`` handle). When
    ``opt.log`` is set, it also sets ``opt.writer``, ``opt.train_stats`` and
    ``opt.test_stats``.

    Steps:
      1. Create ``opt.out``, copy ``options.py`` into it, and echo the options
         and the command line to stdout and ``log.txt``.
      2. Build the dataloaders and, if ``opt.log``, the TensorBoard writer and
         the train/test CSV stat files.
      3. Either train (``opt.test`` false), or load ``opt.weights`` (if given)
         and run ``testAndMakeCombinedPlots``.
      4. After training, render convergence plots from ``log.txt``. If
         ``opt.disposableTrainingData`` is set, copy up to 10 files from
         ``opt.root`` into ``<out>/training_data_subset`` and then delete
         ``opt.root``.
    """
    opt.device = torch.device(
        'cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')

    os.makedirs(opt.out, exist_ok=True)
    shutil.copy2('options.py', opt.out)

    opt.fid = open(opt.out + '/log.txt', 'w')

    ostr = 'ARGS: ' + ' '.join(sys.argv)
    print(opt, '\n')
    print(opt, '\n', file=opt.fid)
    print('\n%s\n' % ostr)
    print('\n%s\n' % ostr, file=opt.fid)

    print('getting dataloader', opt.root)
    dataloader, validloader = GetDataloaders(opt)

    if opt.log:
        out_dir = opt.out.replace('\\', '/')
        opt.writer = SummaryWriter(
            log_dir=opt.out,
            comment='_%s_%s' % (out_dir.split('/')[-1], opt.model))
        opt.train_stats = open(out_dir + '/train_stats.csv', 'w')
        opt.test_stats = open(out_dir + '/test_stats.csv', 'w')
        print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
        print('iter,time,memory,psnr,ssim', file=opt.test_stats)

    t0 = time.perf_counter()
    net = GetModel(opt)

    if not opt.test:
        train(opt, dataloader, validloader, net)
    else:
        if opt.weights:  # load previous weights?
            # map_location lets a checkpoint saved from GPU tensors load when
            # opt.device is the CPU (--cpu, or no CUDA available).
            checkpoint = torch.load(opt.weights, map_location=opt.device)
            print('loading checkpoint', opt.weights)
            net.load_state_dict(checkpoint['state_dict'])
            print('time: %0.1f' % (time.perf_counter() - t0))
        testAndMakeCombinedPlots(net, validloader, opt)

    opt.fid.close()
    if opt.log:
        # Flush and release the CSV stat logs (previously never closed).
        opt.train_stats.close()
        opt.test_stats.close()
    if not opt.test:
        generate_convergence_plots(opt, opt.out + '/log.txt')

    print('time: %0.1f' % (time.perf_counter() - t0))

    if opt.log:
        # Closed last, in case the plotting step above still uses the writer.
        opt.writer.close()

    # optional clean up
    if opt.disposableTrainingData and not opt.test:
        print('deleting training data')
        # preserve a few samples
        os.makedirs('%s/training_data_subset' % opt.out, exist_ok=True)

        samplecount = 0
        for path in glob.glob('%s/*' % opt.root):
            if os.path.isfile(path):
                basename = os.path.basename(path)
                shutil.copy2(
                    path, '%s/training_data_subset/%s' % (opt.out, basename))
                samplecount += 1
                if samplecount == 10:
                    break
        shutil.rmtree(opt.root)
예제 #7
0
    torch.save(checkpoint, opt.out + '/final.pth')


if __name__ == '__main__':

    try:
        os.makedirs(opt.out)
    except IOError:
        pass

    opt.fid = open(opt.out + '/log.txt', 'w')
    print(opt)
    print(opt, '\n', file=opt.fid)

    dataloader, validloader = GetDataloaders(opt)
    net = GetModel(opt)
    net.cpu()
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        else:
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']