Example #1
def main(argv):
    
    data_path = None
    input_image = None
    output = None
    weight_path = None
    mode=5
    drop_rate=0
    lab=s.lab
    classification=False
    temp=.4
    try:
        opts, args = getopt.getopt(argv,"w:p:b:m:ld:ct:i:o:",["weight-path=", "datapath=",'model=','lab','drop-rate=','input=','output='])
    except getopt.GetoptError as error:
        print(error)
        print( 'demo.py -w <path to weights file> -p <path to folder of images> OR -i <path to single image> -m <mode: different models> -l <no argument. use if lab should be used>\
            -d <amount of dropout used in model> -c <no argument. Use if model is classifier> -t <temperature for annealed mean> -o <output path for images>')
        sys.exit(2)
    print("opts", opts)
    for opt, arg in opts:
        if opt in ("-w", "--weight-path"):
            weight_path = arg
        elif opt in ("--datapath", "-p"):
            data_path = arg
        elif opt=='-m':
            if arg in ('custom','0'):
                mode = 0
            elif arg in ('u','1','unet'):
                mode = 1
            elif arg in ('ende','2'):
                mode = 2
            elif arg in ('richzhang','classende','3'):
                mode = 3
            elif arg in ('colorunet','cu','4'):
                mode = 4
            elif arg in ('mu','5','middle'):
                mode = 5
        elif opt in ('-l','--lab'):
            lab=True
        elif opt in ("-d", "--drop-rate"):
            drop_rate = float(arg) 
        elif opt =='-c':
            classification=True
            lab=True
        elif opt=='-t':
            temp=float(arg)
        elif opt in ('-i','--input'):
            input_image = arg
        elif opt in ('-o','--output'):
            output = arg

    if data_path is None and input_image is None:
        print('Please select an image or folder')
        sys.exit()
    trafo=transforms.Compose([transforms.Grayscale(3 if lab else 1), transforms.Resize((96,96))])
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if data_path is not None: 
        dataset = ImageDataset(data_path,lab=lab,pretrafo=trafo)
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    if input_image is not None:
        img=trafo(Image.open(input_image))
        if lab:
            img=color.rgb2lab(np.asarray(img)/255)[...,:1]-np.array([50])[None,None,:]
        loader = [(transforms.ToTensor()(img)[None,...].float(),input_image)]
        
    
    classes=(150 if classification else 2) if lab else 3

    #define model
    UNet=None
    zoom=False
    if mode == 0:
        UNet=model(col_channels=classes) 
    elif mode ==1:
        UNet=unet(drop_rate=drop_rate,classes=classes)
    elif mode ==2:
        UNet=generator(drop_rate,classes)
    elif mode == 3:
        UNet=richzhang(drop_rate,classes)
        zoom=True
    elif mode == 4:
        UNet=color_unet(True,drop_rate,classes)
    elif mode == 5:
        UNet = middle_unet(True,drop_rate,classes)
    #load weights
    try:
        UNet.load_state_dict(torch.load(weight_path, map_location=device))
        print("Loaded network weights from", weight_path)
    except FileNotFoundError:
        print("Did not find weight files.")
        sys.exit()
    outpath=None
    UNet.to(device)  
    UNet.eval()
    with torch.no_grad():
        for i,(X,name) in enumerate(loader):
            X=X.to(device)
            unet_col=UNet(X)
            col=show_colorization(unet_col,original=X,lab=lab,cl=classification,zoom=zoom,T=temp,return_img=output is not None)
            if output:
                try:
                    fp,f=os.path.split(name)
                except TypeError:
                    fp,f=os.path.split(name[0])
                n, e = os.path.splitext(f)
                f = n + '_color.png'
                outpath=output if os.path.isdir(output) or os.path.isdir(os.path.basename(output)) else fp
                Image.fromarray(toInt(col[0])).save(os.path.join(outpath,f))
    if output:
        print('Finished colorization. Go to "%s" to see the colorized version(s) of the image(s)'%os.path.realpath(outpath))
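
For Lab input, the single-image branch above keeps only the lightness channel and shifts it by -50 so the network input is roughly zero-centered. Below is a minimal standalone sketch of that preprocessing step, assuming color is skimage.color as in these examples; the file path is a placeholder.

import numpy as np
from PIL import Image
from skimage import color

def to_centered_L(path):
    # RGB image file -> (H, W, 1) array with the Lab L channel shifted by -50,
    # mirroring the single-image branch of the demo script above
    rgb = np.asarray(Image.open(path).convert('RGB')) / 255.0
    L = color.rgb2lab(rgb)[..., :1]   # keep only lightness, roughly in [0, 100]
    return L - 50.0                   # center around zero before feeding the network
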
Example #2
def main(argv):
    # setting argument defaults
    mbsize = s.batch_size
    report_freq = s.report_freq
    weight_path = s.weights_path
    weights_name = s.weights_name
    lr = s.learning_rate
    save_freq = s.save_freq
    mode = 3
    image_loss_weight = s.image_loss_weight
    epochs = s.epochs
    beta1, beta2 = s.betas
    infinite_loop = s.infinite_loop
    data_path = s.data_path
    drop_rate = 0
    lab = True
    weighted_loss = True
    weight_lambda = .25
    load_list = s.load_list
    help = 'train_classification.py -b <batch size> -e <amount of epochs to train. standard: infinite> -r <report frequency> -w <path to weights folder> \
            -n <name> -s <save freq.> -l <learning rate> -p <path to data set> -d <dropout rate> -m <mode: different models> --beta1 <beta1 for adam>\
            --beta2 <beta2 for adam> --lab <No argument. If used lab colorspace is used> --weighted <No argument. If used *NO* class weights are used> \
            --lambda <hyperparameter for class weights>'

    try:
        opts, args = getopt.getopt(argv, "he:b:r:w:l:s:n:p:d:i:m:", [
            'epochs=', "mbsize=", "report-freq=", 'weight-path=', 'lr=',
            'save-freq=', 'weight-name=', 'data_path=', 'drop_rate='
            'beta1=', 'beta2=', 'lab', 'image-loss-weight=', 'weighted',
            'mode=', 'lambda='
        ])
    except getopt.GetoptError:
        print(help)
        sys.exit(2)
    print("opts", opts)
    for opt, arg in opts:
        if opt == '-h':
            print(help)
            sys.exit()
        elif opt in ("-b", "--mbsize"):
            mbsize = int(arg)
        elif opt in ("-e", "--epochs"):
            epochs = int(arg)
            infinite_loop = False
        elif opt in ('-r', '--report-freq'):
            report_freq = int(arg)
        elif opt in ("-w", "--weight-path"):
            weight_path = arg
        elif opt in ("-n", "--weight-name"):
            weights_name = arg
        elif opt in ("-s", "--save-freq"):
            save_freq = int(arg)
        elif opt in ("-l", "--lr"):
            lr = float(arg)
        elif opt in ("-p", "--data_path"):
            data_path = str(arg)
        elif opt in ("-d", "--drop_rate"):
            drop_rate = float(arg)
        elif opt == '-m':
            if arg in ('richzhang', '0', 'ende'):
                mode = 0
            elif arg in ('u', '1', 'unet'):
                mode = 1
            elif arg in ('color', '2', 'cu'):
                mode = 2
            elif arg in ('mu', '3', 'middle'):
                mode = 3
        elif opt == '--beta1':
            beta1 = float(arg)
        elif opt == '--beta2':
            beta2 = float(arg)
        elif opt == '--lab':
            lab = True
        elif opt == '--weighted':
            weighted_loss = not weighted_loss
        elif opt == '--load-list':
            load_list = not load_list
        elif opt == '--lambda':
            weight_lambda = float(arg)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = None
    in_size = 256
    if 'cifar' in data_path:
        in_size = 32
        dataset = 0
    elif 'places' in data_path:
        in_size = 224
        dataset = 1
    elif 'stl' in data_path:
        in_size = 96
        dataset = 2
    in_shape = (3, in_size, in_size)

    #out_shape=(s.classes,32,32)
    betas = (beta1, beta2)
    weight_path_ending = os.path.join(weight_path, weights_name + '.pth')

    loss_path_ending = os.path.join(weight_path,
                                    weights_name + "_" + s.loss_name)

    trainset = load_trainset(data_path,
                             lab=lab,
                             normalize=False,
                             load_list=load_list)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=mbsize,
        shuffle=True,
        num_workers=2 if dataset in (0, 1) else 0)

    print("NETWORK PATH:", weight_path_ending)
    #define output channels of the model
    classes = 150
    #define model
    if mode == 0:
        classifier = generator(drop_rate, classes)
    elif mode == 1:
        classifier = unet(True, drop_rate, classes)
    elif mode == 2:
        classifier = color_unet(True, drop_rate, classes)
    elif mode == 3:
        classifier = middle_unet(True, drop_rate, classes)
    #load weights
    try:
        classifier.load_state_dict(torch.load(weight_path_ending))
        print("Loaded network weights from", weight_path)
    except FileNotFoundError:
        print("Initialize new weights for the generator.")
        #sys.exit(2)

    classifier.to(device)

    #save the hyperparameters to a JSON-file for better organization
    model_description_path_ending = os.path.join(weight_path,
                                                 s.model_description_name)
    # initialize model dict
    try:
        with open(model_description_path_ending, "r") as file:
            model_dict = json.load(file)
    except FileNotFoundError:
        model_dict = {}

    prev_epochs = 0
    # save settings in dict if new weights are being initialized
    if weights_name not in model_dict:
        model_dict[weights_name] = {
            "loss_name": loss_path_ending,
            "epochs": 0,
            "batch_size": mbsize,
            "lr": lr,
            "lab": lab,
            "betas": betas,
            "image_loss_weight": image_loss_weight,
            "weighted_loss": weighted_loss,
            "model": 'classification ' +
                     ['richzhang', 'U-Net', 'color U-Net', 'middle U-Net'][mode]
        }
    else:
        #load specified parameters from model_dict
        params = model_dict[weights_name]
        #mbsize=params['batch_size']
        betas = params['betas']
        #lr=params['lr']
        lab = params['lab']
        image_loss_weight = params['image_loss_weight']
        weighted_loss = params['weighted_loss']
        loss_path_ending = params['loss_name']
        #memorize how many epochs already were trained if we continue training
        prev_epochs = params['epochs'] + 1

    #optimizer
    optimizer = optim.Adam(classifier.parameters(), lr=lr, betas=betas)
    class_weight_path = 'resources/class-weights.npy'
    if weighted_loss:
        weights = np.load(class_weight_path)
        if dataset == 0:
            class_weight_path = 'resources/cifar-lab-class-weights.pt'
            weights = torch.load(class_weight_path).numpy()
        elif dataset == 2:
            if weight_lambda:
                class_weight_path = 'resources/probdist_lab.pt'
                prob_dict = torch.load(class_weight_path)
                prob = np.array(list(prob_dict.values()))
                weights = 1 / ((1 - weight_lambda) * prob / prob.sum() +
                               weight_lambda / classes)
            else:
                class_weight_path = 'resources/class-weights-lab150-stl.pt'
                weights = torch.load(class_weight_path)

        print('Class-weights loaded from ' + class_weight_path)
    criterion = softCossEntropyLoss(weights=weights if weighted_loss else None,
                                    device=device)
    loss_hist = []
    soft_onehot = torch.load('resources/smooth_onehot150.pt',
                             map_location=device)

    classifier.train()
    # run over epochs
    for e in (range(prev_epochs, prev_epochs +
                    epochs) if not infinite_loop else count(prev_epochs)):
        g_running = 0
        #load batches
        for i, batch in enumerate(trainloader):

            if dataset == 0:  #cifar 10
                (image, _) = batch
            elif dataset in (1, 2):  #places
                image = batch

            #batch_size=image.shape[0]
            if dataset == 0:  #cifar/stl 10
                image = np.transpose(image, (0, 2, 3, 1))
                image = np.transpose(color.rgb2lab(image), (0, 3, 1, 2))
                image = torch.from_numpy(
                    (image -
                     np.array([50, 0, 0])[None, :, None, None])).float()

            X = image[:, :1, :, :].to(
                device)  #set X to the Lightness of the image
            image = image[:, 1:, :, :].to(device)  #image is a and b channel

            #----------------------------------------------------------------------------------------
            ################################### Model optimization ##################################
            #----------------------------------------------------------------------------------------
            #clear gradients
            optimizer.zero_grad()
            #softmax activated distribution
            model_out = classifier(X).double()
            #create a bin-coded version of the ab ground truth
            binab = ab2bins(image.transpose(1, 3).transpose(1, 2))
            if mode == 0:
                binab = F.interpolate(binab.float(),
                                      scale_factor=(.25, .25)).long()
            binab = torch.squeeze(binab, 1)
            binab = soft_onehot[:, binab].transpose(0, 1).double()
            #calculate loss
            loss = criterion(model_out, binab).mean(0)

            loss.backward()
            optimizer.step()

            g_running += loss.item()
            loss_hist.append([e, loss.item()])

            #report running loss
            if (i + len(trainloader) * e) % report_freq == report_freq - 1:
                print('Epoch %i, batch %i: \tloss=%.2e' %
                      (e + 1, i + 1, g_running / report_freq))
                g_running = 0

            if s.save_weights and (
                    i + len(trainloader) * e) % save_freq == save_freq - 1:
                #save parameters
                try:
                    torch.save(classifier.state_dict(), weight_path_ending)
                    #torch.save(crit.state_dict(),crit_path)
                except FileNotFoundError:
                    os.makedirs(weight_path)
                    torch.save(classifier.state_dict(), weight_path_ending)
                    #torch.save(crit.state_dict(),crit_path)
                print("Parameters saved")

                if s.save_loss:
                    #save loss history to file
                    try:
                        f = open(loss_path_ending, 'a')
                        np.savetxt(f, loss_hist, '%e')
                        f.close()
                    except FileNotFoundError:
                        os.makedirs(s.loss_path)
                        np.savetxt(loss_path_ending, loss_hist, '%e')
                    loss_hist = []

        #update epoch count in dict after each epoch
        model_dict[weights_name]["epochs"] = e
        #save it to file
        try:
            with open(model_description_path_ending, "w") as file:
                json.dump(model_dict, file, sort_keys=True, indent=4)
        except Exception:
            print('Could not save to model dictionary (JSON-file)')
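
The dataset branch above builds class-rebalancing weights as the inverse of a lambda-smoothed empirical bin distribution, w_q = 1 / ((1 - lambda) * p_q + lambda / Q). A small self-contained sketch of that formula follows; the toy probabilities are invented purely for illustration.

import numpy as np

def rebalancing_weights(prob, weight_lambda=0.25, classes=150):
    # inverse of the lambda-smoothed empirical distribution, as in the training script above
    prob = prob / prob.sum()          # normalize the empirical bin distribution
    return 1.0 / ((1 - weight_lambda) * prob + weight_lambda / classes)

# toy example with 4 bins: rare bins receive much larger weights than frequent ones
print(rebalancing_weights(np.array([0.5, 0.3, 0.15, 0.05]), classes=4))
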
Example #3
def main(argv):

    data_path = s.data_path
    weight_path = s.weights_path
    mode = 1
    drop_rate = 0
    lab = s.lab
    classification = False
    temp = .4
    try:
        opts, args = getopt.getopt(argv, "h:w:p:b:m:ld:ct:", [
            "help", "weight-path=", "datapath=", 'model=', 'lab', 'drop-rate='
        ])
    except getopt.GetoptError as error:
        print(error)
        print(
            'test.py -w <path to weights file> -p <path to dataset> -l <no argument. use if lab should be used> -m <mode: different models>\
            -d <amount of dropout used in model> -c <no argument. Use if model is classifier> -t <temperature for annealed mean> -o <output path for images>'
        )
        sys.exit(2)
    print("opts", opts)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print('test.py -w <weights file> -p <dataset path> -m <model> -l <use lab> -d <dropout rate> -c <classifier> -t <temperature>')
            sys.exit()
        elif opt in ("-w", "--weight-path"):
            weight_path = arg
        elif opt in ("--datapath", "-p"):
            data_path = arg
        elif opt in ("--batchnorm", "-b"):
            batch_norm = arg in ["True", "true", "1"]
        elif opt == '-m':
            if arg in ('custom', '0'):
                mode = 0
            elif arg in ('u', '1', 'unet'):
                mode = 1
            elif arg in ('ende', '2'):
                mode = 2
            elif arg in ('richzhang', 'classende', '3'):
                mode = 3
            elif arg in ('colorunet', 'cu', '4'):
                mode = 4
            elif arg in ('mu', '5', 'middle'):
                mode = 5
        elif opt in ('-l', '--lab'):
            lab = True
        elif opt in ("-d", "--drop-rate"):
            drop_rate = float(arg)
        elif opt == '-c':
            classification = True
            lab = True
        elif opt == '-t':
            temp = float(arg)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = None
    if data_path == './cifar-10':
        in_size = 32
        dataset = 0
    elif 'places' in data_path:
        in_size = 224
        dataset = 1
    elif 'stl' in data_path:
        in_size = 96
        dataset = 2
    in_shape = (3, in_size, in_size)
    #out_shape=(s.classes,32,32)

    trainset = load_trainset(data_path, train=False, lab=lab)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=3,
        shuffle=True,
        num_workers=2 if dataset in (0, 1) else 0)
    print("Loaded dataset from", data_path)
    classes = (150 if classification else 2) if lab else 3

    #define model
    UNet = None
    zoom = False
    if mode == 0:
        UNet = model(col_channels=classes)
    elif mode == 1:
        UNet = unet(drop_rate=drop_rate, classes=classes)
    elif mode == 2:
        UNet = generator(drop_rate, classes)
    elif mode == 3:
        UNet = richzhang(drop_rate, classes)
        zoom = True
    elif mode == 4:
        UNet = color_unet(True, drop_rate, classes)
    elif mode == 5:
        UNet = middle_unet(True, drop_rate, classes)
    #load weights
    try:
        UNet.load_state_dict(torch.load(weight_path, map_location=device))
        print("Loaded network weights from", weight_path)
    except FileNotFoundError:
        print("Did not find weight files.")
        #sys.exit(2)

    UNet.to(device)
    UNet.eval()
    gray = torch.tensor([0.2989, 0.5870, 0.1140])[:, None, None].float()
    with torch.no_grad():
        for i, batch in enumerate(trainloader):
            if dataset == 0:  #cifar 10
                (image, _) = batch
            elif dataset in (1, 2):  #places
                image = batch
            X = None
            if lab:
                if dataset == 0:  #cifar 10
                    image = np.transpose(image, (0, 2, 3, 1))
                    image = np.transpose(color.rgb2lab(image), (0, 3, 1, 2))
                    image = torch.from_numpy(
                        (image -
                         np.array([50, 0, 0])[None, :, None, None])).float()
                X = torch.unsqueeze(image[:, 0, :, :], 1).to(
                    device)  #set X to the Lightness of the image
                image = image[:, 1:, :, :]  #image is a and b channel
            else:
                #convert to grayscale image

                #using the matlab formula: 0.2989 * R + 0.5870 * G + 0.1140 * B and load data to gpu
                X = (image.clone() * gray).sum(1).to(device).view(
                    -1, 1, *in_shape[1:])
                image = image.float()
            #print(X.min(),X.max())
            #generate colorized version with unet
            #for arr in (image[:,0,...],image[:,1,...],X):
            #    print(arr.min(),arr.max())
            try:
                unet_col = UNet(X)
            except RuntimeError:
                unet_col = UNet(torch.stack((X, X, X), 1)[:, :, 0, :, :])
            #for arr in (unet_col[0,...],unet_col[1,...]):
            #    print(arr.min().item(),arr.max().item())
            show_colorization(unet_col,
                              image,
                              X,
                              lab=lab,
                              cl=classification,
                              zoom=zoom,
                              T=temp)
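
When Lab is not used, the loop above converts RGB batches to luminance with a broadcasted weighted sum using the coefficients 0.2989, 0.5870 and 0.1140. A compact standalone sketch of just that conversion:

import torch

def rgb_to_gray(batch):
    # batch: (N, 3, H, W) RGB tensor -> (N, 1, H, W) luminance,
    # same weighted sum as in the test loop above
    gray = torch.tensor([0.2989, 0.5870, 0.1140])[None, :, None, None]
    return (batch * gray).sum(dim=1, keepdim=True)

print(rgb_to_gray(torch.rand(2, 3, 96, 96)).shape)  # torch.Size([2, 1, 96, 96])
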
Example #4
def main(argv):
    # setting argument defaults
    mbsize = s.batch_size
    report_freq=s.report_freq
    weight_path=s.weights_path
    weights_name=s.weights_name
    lr=s.learning_rate
    save_freq = s.save_freq
    mode=3
    image_loss_weight=s.image_loss_weight
    epochs = s.epochs
    beta1,beta2=s.betas
    infinite_loop=s.infinite_loop
    data_path = s.data_path
    drop_rate = 0
    lab = s.lab
    load_list = s.load_list
    help='train_gan.py -b <batch size> -e <amount of epochs to train. standard: infinite> -r <report frequency> -w <path to weights folder> \
            -n <name> -s <save freq.> -l <learning rate> -p <path to data set> -d <dropout rate> -m <mode: different models> --beta1 <beta1 for adam>\
            --beta2 <beta2 for adam> --lab <No argument. If used lab colorspace is used> -i <image loss weight> --load-list <no argument>'
    try:
        opts, args = getopt.getopt(argv,"he:b:r:w:l:s:n:m:p:d:i:",
            ['epochs=',"mbsize=","report-freq=",'weight-path=', 'lr=','save-freq=','weight-name=','mode=','data_path=','drop_rate='
            'beta1=','beta2=','lab','image-loss-weight=','load-list'])
    except getopt.GetoptError:
        print(help)
        sys.exit(2)
    print("opts" ,opts)
    for opt, arg in opts:
        if opt == '-h':
            print(help)
            sys.exit()
        elif opt in ("-b", "--mbsize"):
            mbsize = int(arg)
        #elif opt in ("-p", "--data-path"):
        #    data_path = arg
        elif opt in ("-e", "--epochs"):
            epochs = int(arg)
            infinite_loop=False
        elif opt in ('-r','--report-freq'):
            report_freq = int(arg)
        elif opt in ("-w", "--weight-path"):
            weight_path = arg
        elif opt in ("-n", "--weight-name"):
            weights_name = arg            
        elif opt in ("-s", "--save-freq"):
            save_freq=int(arg)
        elif opt in ("-l", "--lr"):
            lr = float(arg)
        elif opt=='-m':
            if arg in ('custom','0'):
                mode = 0
            elif arg in ('u','1','unet'):
                mode = 1
            elif arg in ('ende','2'):
                mode = 2
            elif arg in ('mu','3','middle'):
                mode = 3
            elif arg in ('cu','4','colorunet'):
                mode = 4
        elif opt in ("-p", "--data_path"):
            data_path = str(arg)
        elif opt in ("-d", "--drop_rate"):
            drop_rate = float(arg)
        elif opt=='--beta1':
            beta1 = float(arg)
        elif opt=='--beta2':
            beta2 = float(arg)
        elif opt=='--lab':
            lab=True
        elif opt in ('-i','--image-loss-weight'):
            image_loss_weight=float(arg)
        elif opt == '--load-list':
            load_list=True

    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset=None
    if 'cifar' in data_path:
        in_size = 32
        dataset = 0
    elif 'places' in data_path:
        in_size = 224
        dataset = 1
    elif 'stl' in data_path:
        in_size = 96
        dataset = 2
    in_shape=(3,in_size,in_size)

    betas=(beta1,beta2)
    weight_path_ending=os.path.join(weight_path,weights_name+'.pth')

    loss_path_ending = os.path.join(weight_path, weights_name + "_" + s.loss_name)

    trainset = load_trainset(data_path,lab=lab,load_list=load_list)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=mbsize,
                                        shuffle=True, num_workers=2)
 
    print("NETWORK PATH:", weight_path_ending)
    #define output channels of the model
    classes=2 if lab else 3
    #define model
    UNet=None
    try:
        if mode ==0:
            UNet=model(col_channels=classes) 
        elif mode ==1:
            UNet=unet(drop_rate=drop_rate,classes=classes)
        elif mode ==2:
            UNet=generator(drop_rate,classes)
        elif mode ==3:
            UNet=middle_unet(drop_rate=drop_rate,classes=classes)
        elif mode ==4:
            UNet=color_unet(drop_rate=drop_rate,classes=classes)
        
        #load weights
        try:
            UNet.load_state_dict(torch.load(weight_path_ending))
            print("Loaded network weights from", weight_path)
        except FileNotFoundError:
            print("Initialize new weights for the generator.")

    except RuntimeError:
        #if the wrong mode was chosen: try the other one
        UNet=model(col_channels=classes) if mode==1 else unet(classes=classes)
        #load weights
        try:
            UNet.load_state_dict(torch.load(weight_path_ending))
            print("Loaded network weights from", weight_path)
            #change mode to the correct one
            mode = (mode +1) %2
        except FileNotFoundError:
            print("Initialize new weights for the generator.")
   
    UNet.to(device)

    #save the hyperparameters to a JSON-file for better organization
    model_description_path_ending = os.path.join(weight_path, s.model_description_name)
    # initialize model dict
    try:
        with open(model_description_path_ending, "r") as file:
            model_dict = json.load(file)
    except FileNotFoundError:
        model_dict = {}


    prev_epochs=0
    # save settings in dict if new weights are being initialized
    if weights_name not in model_dict:
        model_dict[weights_name] = {
            "loss_name": loss_path_ending,
            "epochs": 0,
            "batch_size": mbsize,
            "lr": lr,
            "lab":lab,
            "betas": betas,
            "image_loss_weight": image_loss_weight,
            "model":['custom','unet','encoder-decoder','middle-unet','color-unet'][mode]
        }
    else:
        #load specified parameters from model_dict
        params=model_dict[weights_name]
        mbsize=params['batch_size']
        betas=params['betas']
        lr=params['lr']
        lab=params['lab']
        image_loss_weight=params['image_loss_weight']
        loss_path_ending=params['loss_name']
        #memorize how many epochs already were trained if we continue training
        prev_epochs=params['epochs']+1

    
    #define critic 
    crit=critic(in_size,classes=classes).to(device)
    #load discriminator weights
    crit_path=os.path.join(weight_path,weights_name+'_crit.pth')
    try:
        crit.load_state_dict(torch.load(crit_path))
        print('Loaded weights for discriminator from %s'%crit_path)
    except FileNotFoundError:
        print('Initialize new weights for discriminator')
        crit.apply(weights_init_normal)
    #optimizer
    optimizer_g=optim.Adam(UNet.parameters(),lr=lr,betas=betas)
    optimizer_c=optim.Adam(crit.parameters(),lr=lr,betas=betas)
    criterion = nn.BCELoss().to(device)
    #additional gan loss: l1 loss
    l1loss = nn.L1Loss().to(device)
    loss_hist=[]

    

    UNet.train()
    crit.train()
    gray = torch.tensor([0.2989 ,0.5870, 0.1140 ])[:,None,None].float()
    ones = torch.ones(mbsize,device=device)
    zeros= torch.zeros(mbsize,device=device)
    # run over epochs
    for e in (range(prev_epochs, prev_epochs + epochs) if not infinite_loop else count(prev_epochs)):
        g_running,c_running=0,0
        #load batches          
        for i,batch in enumerate(trainloader):
            if dataset == 0: #cifar 10
                (image,_) = batch
            elif dataset in (1,2): #places
                image = batch
                
            batch_size=image.shape[0]
            X=None
            #differentiate between the two available color spaces RGB and Lab
            if lab:
                if dataset == 0: #cifar 10
                    image=np.transpose(image,(0,2,3,1))
                    image=np.transpose(color.rgb2lab(image),(0,3,1,2))
                    image=torch.from_numpy((image+np.array([-50,0,0])[None,:,None,None])).float()
                X=torch.unsqueeze(image[:,0,:,:],1).to(device) #set X to the Lightness of the image
                image=image[:,1:,:,:].to(device) #image is a and b channel
            else:
                #convert to grayscale image
                #using the matlab formula: 0.2989 * R + 0.5870 * G + 0.1140 * B and load data to gpu
                X=(image.clone()*gray).sum(1).to(device).view(-1,1,*in_shape[1:])
                image=image.float().to(device)
            #----------------------------------------------------------------------------------------
            ################################### Unet optimization ###################################
            #----------------------------------------------------------------------------------------
            #clear gradients
            optimizer_g.zero_grad()
            #generate colorized version with unet
            unet_col=None
            #print(X.shape,image.shape,classes)
            if mode==0:
                unet_col=UNet(torch.stack((X,X,X),1)[:,:,0,:,:])
            else:
                unet_col=UNet(X)
            #calculate loss as a function of how good the unet can fool the critic
            fooling_loss=criterion(crit(unet_col)[:,0], ones[:batch_size])
            #calculate how close the generated pictures are to the ground truth
            image_loss=l1loss(unet_col,image)
            #combine both losses and weight them
            loss_g=fooling_loss+image_loss_weight*image_loss
            #backpropagation
            loss_g.backward()
            optimizer_g.step()

            #----------------------------------------------------------------------------------------
            ################################## Critic optimization ##################################
            #----------------------------------------------------------------------------------------
            optimizer_c.zero_grad()
            real_loss=criterion(crit(image)[:,0],ones[:batch_size])
            #requires no gradient in unet col
            fake_loss=criterion(crit(unet_col.detach())[:,0],zeros[:batch_size])
            loss_c=.5*(real_loss+fake_loss)
            loss_c.backward()
            optimizer_c.step()

            g_running+=loss_g.item()
            c_running+=loss_c.item()
            loss_hist.append([e,i,loss_g.item(),loss_c.item()])

            #report running loss
            if (i+len(trainloader)*e)%report_freq==report_freq-1:
                print('Epoch %i, batch %i: \tunet loss=%.2e, \tcritic loss=%.2e'%(e+1,i+1,g_running/report_freq,c_running/report_freq))
                g_running=0
                c_running=0

            if s.save_weights and (i+len(trainloader)*e)%save_freq==save_freq-1:
                #save parameters
                try:
                    torch.save(UNet.state_dict(),weight_path_ending)
                    torch.save(crit.state_dict(),crit_path)
                except FileNotFoundError:
                    os.makedirs(weight_path)
                    torch.save(UNet.state_dict(),weight_path_ending)
                    torch.save(crit.state_dict(),crit_path)
                print("Parameters saved")

                if s.save_loss:
                    #save loss history to file
                    try:
                        f=open(loss_path_ending,'a')
                        np.savetxt(f,loss_hist,'%e')
                        f.close()
                    except FileNotFoundError:
                        os.makedirs(s.loss_path)
                        np.savetxt(loss_path_ending,loss_hist,'%e')
                    loss_hist=[]

        #update epoch count in dict after each epoch
        model_dict[weights_name]["epochs"] = e  
        #save it to file
        try:
            with open(model_description_path_ending, "w") as file:
                json.dump(model_dict, file, sort_keys=True, indent=4)
        except Exception:
            print('Could not save to model dictionary (JSON-file)')        
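
The generator step above combines an adversarial BCE term (how well the critic is fooled) with an L1 reconstruction term scaled by image_loss_weight. A condensed sketch of just that combined objective, written as a hypothetical helper that is not part of the original script:

import torch
import torch.nn as nn

bce, l1 = nn.BCELoss(), nn.L1Loss()

def generator_loss(crit, fake, real, image_loss_weight):
    # adversarial term: the critic should score the generated colorization as real
    ones = torch.ones(fake.shape[0], device=fake.device)
    fooling_loss = bce(crit(fake)[:, 0], ones)
    # reconstruction term: stay close to the ground-truth colors
    image_loss = l1(fake, real)
    return fooling_loss + image_loss_weight * image_loss
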
Example #5
def main(argv):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gray = torch.tensor([0.2989 ,0.5870, 0.1140])[:,None,None].float()

    alex = alexnet().to(device)
    alex.load_state_dict(torch.load('weights/alexnet.pth'))
    alex.eval()

    classifier_path = 'weights/alexNorm.pt'
    weight_path = None
    mbsize = 16
    col_space = None
    class_net = False
    mode = 2
    drop_rate = 0
    T = .4
    classes = 3
    stl_path = './stl-10'
    help='metric.py -b <batch size> -w <model weight file> -p <2layer classifier weight file> -s <colorspace "rgb" or "lab"> \
        -c <use if model is classification model> -m <mode> -d <dropout rate> -t <define temperature> --stl <path to stl dataset>'
    try:
        opts, args = getopt.getopt(argv,"b:w:p:cm:d:t:s:",
            ["mbsize=",'weight-path=','classifier-path=','c_space=','mode=','drop_rate=','temperature=','stl='])
    except getopt.GetoptError:
        print(help)
        sys.exit(2)
    print("opts" ,opts)
    for opt, arg in opts:
        if opt in ("-b", "--mbsize"):
            mbsize = int(arg)
        elif opt in ("-w", "--weight-path"):
            weight_path = str(arg)
        elif opt in ("-p", "--classifier-path"):
            classifier_path = str(arg)
        elif opt in ("-s","--c_space"):
            col_space = str(arg)
        elif opt =='-c':
            class_net = True
        elif opt == '-m':
            if arg in ('u','0','unet'):
                mode = 0
            elif arg in ('color','1','cu'):
                mode = 1
            elif arg in ('mu','2','middle'):
                mode = 2
        elif opt in ("-d", "--drop_rate"):
            drop_rate = float(arg)
        elif opt in ('-t', '--temperature'):
            T = float(arg)
        elif opt == '--stl':
            stl_path = arg
    if col_space is None:
        print('Specify color space')
        sys.exit(2)

    if class_net:
        if col_space == 'yuv':
            classes = 42
        elif col_space == 'lab':
            classes = 150
    if mode == 0:
        Col_Net = unet(True, drop_rate, classes).to(device)
    elif mode == 1:
        Col_Net = color_unet(True, drop_rate, classes).to(device)
    elif mode == 2:
        Col_Net = middle_unet(True, drop_rate, classes).to(device)

    Col_Net.load_state_dict(torch.load(weight_path, map_location=device))
    Col_Net.eval()

    classifier=nn.Sequential(nn.Linear(1000,512),nn.ReLU(),nn.Linear(512,10)).to(device)
    classifier.load_state_dict(torch.load(classifier_path, map_location=device))
    classifier.eval()

    #we trained on the test set and evaluate on the train set since the test set has more labeled images than the trainset
    testset = datasets.STL10(stl_path,split='train',transform=transforms.Compose([transforms.ToTensor()]))
    testloader = torch.utils.data.DataLoader(testset, batch_size=mbsize, shuffle=True, num_workers=0)

    print(pseudo_metric(testloader, col_space, Col_Net, classifier, alex, T))
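
Like the other examples, this script is driven entirely by its command line. A hypothetical invocation sketch assembled from the flags parsed above, assuming the script's existing sys import; all paths are placeholders.

# minimal entry point, assuming the usual pattern (not shown in the example)
if __name__ == '__main__':
    main(sys.argv[1:])

# e.g.: python metric.py -w weights/middle_unet.pth -s lab -m mu -t 0.4 --stl ./stl-10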