Example 1
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Build the training dataset with image preprocessing
    train_set = TrainImageFolder(args.image_dir, original_transform)

    # Build data loader
    data_loader = torch.utils.data.DataLoader(train_set,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    # Build the models
    model = nn.DataParallel(Color_model()).cuda()
    # model.load_state_dict(torch.load('../model/models/model-171-216.ckpt'))
    encode_layer = NNEncLayer()
    boost_layer = PriorBoostLayer()
    nongray_mask = NonGrayMaskLayer()
    # Loss and optimizer. reduction='none' replaces the deprecated
    # reduce=False and keeps the per-pixel losses so they can be weighted.
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    params = list(model.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        try:
            for i, (images, img_ab) in enumerate(data_loader):
                try:
                    # Set mini-batch dataset
                    images = images.unsqueeze(1).float().cuda()
                    img_ab = img_ab.float()
                    encode, max_encode = encode_layer.forward(img_ab)
                    targets = torch.Tensor(max_encode).long().cuda()
                    boost = torch.Tensor(boost_layer.forward(encode)).float().cuda()
                    mask = torch.Tensor(nongray_mask.forward(img_ab)).float().cuda()
                    boost_nongray = boost * mask
                    outputs = model(images)
                    # Debug: inspect target and predicted color-bin sets.
                    # print('set_tar', set(targets[0].cpu().data.numpy().flatten()))
                    # output = outputs[0].cpu().data.numpy()
                    # out_max = np.argmax(output, axis=0)
                    # print('set', set(out_max.flatten()))
                    # Per-pixel cross-entropy weighted by the rarity boost and
                    # the non-gray mask, then averaged over the batch.
                    loss = (criterion(outputs, targets) * boost_nongray.squeeze(1)).mean()
                    model.zero_grad()
                
                    loss.backward()
                    optimizer.step()

                    # Print log info
                    if i % args.log_step == 0:
                        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                          .format(epoch, args.num_epochs, i, total_step, loss.item()))

                    # Save the model checkpoints
                    if (i + 1) % args.save_step == 0:
                        torch.save(model.state_dict(), os.path.join(
                            args.model_path, 'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                except Exception:
                    # Skip batches that fail to load or encode rather than
                    # aborting the whole epoch.
                    continue
        except Exception:
            # If the data loader itself fails, move on to the next epoch.
            continue
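
Example 1 (and the examples that follow) assume an import header that is not shown. A plausible reconstruction (the 'training_layers' and 'model' module names are assumptions based on the comments in Example 3; the dataset helpers come from the repository's own data module):

import os
import time

import numpy as np
import torch
import torch.nn as nn

# Project-local modules; names inferred from the comments in Example 3.
from model import Color_model
from training_layers import NNEncLayer, PriorBoostLayer, NonGrayMaskLayer
# TrainImageFolder and original_transform come from the repository's data
# loading module.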
Example 2
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Build the training dataset; the down-sampling rate depends on the chosen model
    train_set = TrainImageFolder(args.img_list_path, down_rate=model_dict[args.model]['down_rate'])

    # Build data loader
    data_loader = torch.utils.data.DataLoader(train_set,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    # Build the models
    model = nn.DataParallel(model_dict[args.model]['structure']()).cuda()

    if not args.train_from_scratch and os.path.exists(args.checkpoint_path):
        model.load_state_dict(torch.load(args.checkpoint_path))
        # Recover the epoch number from a checkpoint named
        # 'model-<epoch>-<step>.ckpt'
        start_epoch = int(args.checkpoint_path.split(os.path.sep)[-1].split('.')[0].split('-')[1])
    else:
        start_epoch = 1
    encode_layer = NNEncLayer()
    boost_layer = PriorBoostLayer()
    nongray_mask = NonGrayMaskLayer()
    # Loss and optimizer. reduction='none' replaces the deprecated
    # reduce=False and keeps the per-pixel losses so they can be weighted.
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    params = list(model.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    

    # Train the models
    print("[Start Training]")
    total_step = len(data_loader)
    for epoch in range(start_epoch, args.num_epochs + 1):  # epochs are 1-based here
        print("-----Epoch {}/{}-----".format(epoch, args.num_epochs))
        start_time = time.time()
        error_count = 0
        for i, (images, img_ab) in enumerate(data_loader):
            try:
                # Set mini-batch dataset
                images = images.unsqueeze(1).float().cuda()
                img_ab = img_ab.float()
                encode, max_encode = encode_layer.forward(img_ab)
                targets = torch.Tensor(max_encode).long().cuda()
                boost = torch.Tensor(boost_layer.forward(encode)).float().cuda()
                mask = torch.Tensor(nongray_mask.forward(img_ab)).float().cuda()
                boost_nongray = boost * mask
                outputs = model(images)
                # Debug: inspect which color bins the model currently predicts.
                # output = outputs[0].cpu().data.numpy()
                # out_max = np.argmax(output, axis=0)
                # print('set', set(out_max.flatten()))
                # Per-pixel cross-entropy weighted by the rarity boost and the
                # non-gray mask, then averaged over the batch.
                loss = (criterion(outputs, targets) * boost_nongray.squeeze(1)).mean()

                model.zero_grad()

                loss.backward()
                optimizer.step()

                # Print log info
                if i % args.log_step == 0:
                    cost_time = time.time() - start_time
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, error_count: {}, cost_time: {:.3f}'
                          .format(epoch, args.num_epochs, i, total_step, loss.item(), error_count, cost_time))
                    start_time = time.time()

                # Save the model checkpoints
                if (i + 1) % args.save_step == 0:
                    torch.save(model.state_dict(), os.path.join(
                        args.model_path, 'model-{}-{}.ckpt'.format(epoch, i + 1)))
            except Exception:
                # Count and report failed batches instead of aborting the epoch.
                error_count += 1
                print("Epoch [{}/{}], Step [{}/{}] Error! Error count: {}".format(
                    epoch, args.num_epochs, i, total_step, error_count
                ))
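
The resume logic above recovers the starting epoch from the checkpoint filename. A standalone sketch of that parsing, assuming the 'model-<epoch>-<step>.ckpt' naming used by the save calls (the example filename below is hypothetical):

import os

def epoch_from_checkpoint(path):
    # 'model-<epoch>-<step>.ckpt' -> <epoch>; mirrors the split chain in
    # the snippet above.
    name = path.split(os.path.sep)[-1]            # e.g. 'model-12-500.ckpt'
    return int(name.split('.')[0].split('-')[1])  # -> 12

print(epoch_from_checkpoint('checkpoints/model-12-500.ckpt'))  # prints 12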
Example 3
def main(args):
    
    # Dataset class that preprocesses images and yields training samples
    train_set = CustomDataset(args.image_dir, original_transform)

    # Data loader that yields shuffled batches of training images
    data_loader = torch.utils.data.DataLoader(train_set,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    # Model instance whose architecture was configured in the 'model.py' file
    model = Color_model().cuda()
    # model.load_state_dict(torch.load(args.load_model))
    
    # Loss function used
    criterion = nn.CrossEntropyLoss().cuda()
    
    # Model parameters and optimizer used during the training step
    params = list(model.parameters())
    optimizer = torch.optim.Adam(params, lr = args.learning_rate)

    # Instance of 'NNEncLayer', responsible for returning a probability
    # distribution over the quantized (a,b) bins for each pixel. This class is
    # in the 'training_layers.py' file.
    encode_ab_layer = NNEncLayer()
    # Instance of 'NonGrayMaskLayer'. Returns 1 for color images and 0 for
    # grayscale images.
    nongraymask_layer = NonGrayMaskLayer()
    # Instance of 'PriorBoostLayer'. Returns a rarity weight for every color bin.
    priorboost_layer = PriorBoostLayer()
    # Instance of 'ClassRebalance'. Weights the colors so that rare colors
    # still contribute to the model's gradients (a sketch of such a Function
    # follows this example).
    class_rebalance_layer = ClassRebalance.apply

    #####################################################################
    #----------------------->> TRAINING STEP <<-------------------------#
    #####################################################################
    print("====>> Training step started!! <<====")

    # Number of batches
    total_batch = len(data_loader)

    # Store the loss of every epoch for training dataset.
    running_loss_history = []

    # Start time to measure time of training
    start_time = time.time()
    
    # Main loop, loop for each epoch
    for epoch in range(args.num_epochs):        

        # Every loss per batch is summed to get the final loss for each epoch for training dataset.
        running_loss = 0.0    

        # Loop for each batch of images
        for i, (images, img_ab, filename) in enumerate(data_loader):
            #print(filename)
            
            # Grayscale input represented by the L channel. unsqueeze(1) adds
            # a channel dimension at position 1; the tensor is then converted
            # to float and moved to the GPU.
            images = images.unsqueeze(1).float().cuda()
            # Ground truth represented by the (a,b) channels
            img_ab = img_ab.float()

            # 'encode_ab' -> a probability distribution over the (a,b) bins
            # for each pixel
            # 'max_encode_ab' -> the index of the highest-probability bin for
            # each pixel
            encode_ab, max_encode_ab = encode_ab_layer.forward(img_ab)
            #encode_ab = torch.from_numpy(encode_ab).long().cuda()

            # 'max_encode_ab' is used as targets. So it is converted to long data type and then loaded to the GPU
            targets = torch.Tensor(max_encode_ab).long().cuda()

            nongray_mask = torch.Tensor(nongraymask_layer.forward(img_ab)).float().cuda()
            prior_boost = torch.Tensor(priorboost_layer.forward(encode_ab)).float().cuda()
            prior_boost_nongray = prior_boost * nongray_mask
            
            # The grayscale inputs are passed through the model; the resulting
            # tensor of shape [B x 313 x H x W] is stored in 'outputs'
            outputs = model(images)
            
            # Class rebalance step: rescales the gradients during backprop
            outputs = class_rebalance_layer(outputs, prior_boost_nongray)

            # loss = (criterion(outputs,targets)*(prior_boost_nongray.squeeze(1))).mean()

            # The loss is computed for each batch
            loss = criterion(outputs,targets)

            # Every loss per batch is summed to get the final loss for each epoch. 
            running_loss += loss.item() 
            
            model.zero_grad()            
            loss.backward()
            optimizer.step()

            # Print info about the training according to the log_step value
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Batch [{}/{}]'.format(epoch+1, args.num_epochs, i+1, total_batch))

            # Save the model at the configured checkpoint epochs, on the last
            # batch of the epoch (integer division keeps the comparison between
            # integers)
            if epoch in args.checkpoint_step and i == (args.trainDataset_length // args.batch_size) - 1:
                torch.save(model.state_dict(), os.path.join(args.model_path, 'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            
        # Average Loss of an epoch for training dataset.
        epoch_loss = running_loss / len(data_loader) # Accumulated loss divided by the number of image batches in the training dataset.
        running_loss_history.append(epoch_loss)
        #print("{:.2f} minutes".format((time.time() - start_time)/60))
        print('--------->>> Epoch [{}/{}], Epoch Loss: {:.4f}'.format(epoch+1, args.num_epochs, epoch_loss))
    
    print('Loss History: {}'.format(running_loss_history))
    print("{:.2f} minutes".format((time.time() - start_time)/60))
    print("                                                    ")

    plt.plot(np.arange(0,args.num_epochs), running_loss_history, label='Training Loss')
    
    ax = plt.gca()
    ax.set_facecolor((0.85, 0.85, 0.85))
    plt.grid(color='w', linestyle='solid')
    ax.set_axisbelow(True)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(args.save_lossCurve)                   
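
The 'ClassRebalance' Function used above lives in 'training_layers.py' and is not shown here. A minimal sketch of what such a gradient-rebalancing Function could look like, assuming it is the identity in the forward pass and scales each pixel's gradient by its rarity weight in the backward pass (the class name below is hypothetical):

import torch

class ClassRebalanceSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, outputs, weights):
        # Identity: predictions pass through unchanged.
        ctx.save_for_backward(weights)
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        (weights,) = ctx.saved_tensors
        # Scale each pixel's gradient by its weight ([B,1,H,W] broadcasts
        # over the 313 bins); no gradient flows to the weights themselves.
        return grad_output * weights, None

# Usage mirrors the training loop above:
# outputs = ClassRebalanceSketch.apply(outputs, prior_boost_nongray)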
Example 4
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Build the training dataset with image preprocessing
    train_set = TrainImageFolder(args.image_dir, original_transform)

    # Build data loader
    data_loader = torch.utils.data.DataLoader(train_set,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    test_data = TrainImageFolder('../../data/custom/test/test/',
                                 transform=transforms.Compose([]))
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers)

    # Build the models
    model = Color_model()
    # model.load_state_dict(torch.load('../model/models/model-171-216.ckpt'))
    encode_layer = NNEncLayer()
    boost_layer = PriorBoostLayer()
    nongray_mask = NonGrayMaskLayer()
    # Loss and optimizer. reduction='none' replaces the deprecated
    # reduce=False and keeps the per-pixel losses so they can be weighted.
    criterion = nn.CrossEntropyLoss(reduction='none')
    params = list(model.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    start_time = time.time()
    for epoch in range(args.num_epochs):
        for i, (images, img_ab) in enumerate(data_loader):
            try:
                # Set mini-batch dataset
                images = images.unsqueeze(1).float()
                img_ab = img_ab.float()
                encode, max_encode = encode_layer.forward(img_ab)
                targets = torch.Tensor(max_encode).long()
                boost = torch.Tensor(boost_layer.forward(encode)).float()
                mask = torch.Tensor(nongray_mask.forward(img_ab)).float()
                boost_nongray = boost * mask
                outputs = model(images)
                # Debug: inspect which color bins the model currently predicts.
                # output = outputs[0].cpu().data.numpy()
                # out_max = np.argmax(output, axis=0)
                # print('set', set(out_max.flatten()))
                # Per-pixel cross-entropy weighted by the rarity boost and the
                # non-gray mask, then averaged over the batch.
                loss = (criterion(outputs, targets) *
                        boost_nongray.squeeze(1)).mean()

                model.zero_grad()

                loss.backward()
                optimizer.step()

                # Print log info
                if i % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                        epoch + 1, args.num_epochs, i + 1, total_step,
                        loss.item()))

                # Save the model checkpoints
                if (i + 1) % args.save_step == 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            args.model_path,
                            'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            except Exception:
                # Skip batches that fail rather than aborting the epoch.
                continue

    print(f"Total time: {time.time() - start_time}")

    model.eval()  # switch to evaluation mode before inference
    with torch.no_grad():
        for b, (X_test, y_test) in enumerate(test_loader):
            X_test = X_test.unsqueeze(1).float()
            y_val = model(X_test)
            color_img = decode(X_test, y_val)
            color_name = '../data/colorimg/' + str(b + 1) + '.jpeg'
            # Scale to 8-bit before saving to avoid float conversion issues
            imageio.imsave(color_name, (color_img * 255.).astype(np.uint8))
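
The weighted loss in Examples 1, 2 and 4 relies on reduction='none' returning one cross-entropy value per pixel. A self-contained shape check (the 56x56 grid size is an assumption; it depends on the model's downsampling):

import torch
import torch.nn as nn

B, Q, H, W = 4, 313, 56, 56
outputs = torch.randn(B, Q, H, W)           # unnormalized scores per ab bin
targets = torch.randint(0, Q, (B, H, W))    # index of the nearest ab bin
weights = torch.rand(B, 1, H, W)            # prior boost * non-gray mask

criterion = nn.CrossEntropyLoss(reduction='none')
per_pixel = criterion(outputs, targets)     # shape [B, H, W]
loss = (per_pixel * weights.squeeze(1)).mean()
print(per_pixel.shape, loss.item())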
Example 5
    def __init__(self):
        super(ColorfulImageColorizationModel, self).__init__(
            data_ab_ss=L.Convolution2D(None, 2, ksize=1, stride=4),
            conv1_1=L.Convolution2D(None, 64, ksize=3, pad=1),
            conv1_2=L.Convolution2D(None, 64, ksize=3, pad=1, stride=2),
            conv1_2norm=L.BatchNormalization(64),
            conv2_1=L.Convolution2D(None, 128, ksize=3, pad=1),
            conv2_2=L.Convolution2D(None, 128, ksize=3, pad=1, stride=2),
            conv2_2norm=L.BatchNormalization(128),
            conv3_1=L.Convolution2D(None, 256, ksize=3, pad=1),
            conv3_2=L.Convolution2D(None, 256, ksize=3, pad=1),
            conv3_3=L.Convolution2D(None, 256, ksize=3, pad=1, stride=2),
            conv3_3norm=L.BatchNormalization(256),
            conv4_1=L.DilatedConvolution2D(256, 512, ksize=3, pad=1, dilate=1),
            conv4_2=L.DilatedConvolution2D(512, 512, ksize=3, pad=1, dilate=1),
            conv4_3=L.DilatedConvolution2D(512, 512, ksize=3, pad=1, dilate=1),
            conv4_3norm=L.BatchNormalization(512),
            conv5_1=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv5_2=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv5_3=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv5_3norm=L.BatchNormalization(512),
            conv6_1=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv6_2=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv6_3=L.DilatedConvolution2D(512, 512, ksize=3, pad=2, dilate=2),
            conv6_3norm=L.BatchNormalization(512),
            conv7_1=L.DilatedConvolution2D(512, 512, ksize=3, pad=1, dilate=1),
            conv7_2=L.DilatedConvolution2D(512, 512, ksize=3, pad=1, dilate=1),
            conv7_3=L.DilatedConvolution2D(512, 512, ksize=3, pad=1, dilate=1),
            conv7_3norm=L.BatchNormalization(512),
            conv8_1=L.Deconvolution2D(512, 256, ksize=4, pad=1, stride=2),
            conv8_2=L.DilatedConvolution2D(256, 256, ksize=3, pad=1, dilate=1),
            conv8_3=L.DilatedConvolution2D(256, 256, ksize=3, pad=1, dilate=1),
            conv313=L.DilatedConvolution2D(256, 313, ksize=1, dilate=1))

        self.prior_boost_layer = PriorBoostLayer()
        self.nn_enc_layer = NNEncLayer()
        # self.class_reblance_layer = ClassRebalanceMultLayer()
        self.non_gray_mask_layer = NonGrayMaskLayer()
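
The class above only declares the links; the forward pass is defined elsewhere in the repository and is not shown. A minimal sketch, assuming the usual conv -> ReLU -> BatchNorm chaining of these links (an illustration, not the repository's actual forward code; data_ab_ss is omitted because it subsamples the ground-truth ab channels rather than taking part in this path):

import chainer.functions as F

def forward_sketch(model, x):
    # Blocks 1-3: plain convolutions; each block ends with a stride-2
    # downsample followed by batch normalization.
    h = F.relu(model.conv1_1(x))
    h = model.conv1_2norm(F.relu(model.conv1_2(h)))
    h = F.relu(model.conv2_1(h))
    h = model.conv2_2norm(F.relu(model.conv2_2(h)))
    h = F.relu(model.conv3_1(h))
    h = F.relu(model.conv3_2(h))
    h = model.conv3_3norm(F.relu(model.conv3_3(h)))
    # Blocks 4-7: dilated convolutions at constant resolution.
    h = F.relu(model.conv4_1(h))
    h = F.relu(model.conv4_2(h))
    h = model.conv4_3norm(F.relu(model.conv4_3(h)))
    h = F.relu(model.conv5_1(h))
    h = F.relu(model.conv5_2(h))
    h = model.conv5_3norm(F.relu(model.conv5_3(h)))
    h = F.relu(model.conv6_1(h))
    h = F.relu(model.conv6_2(h))
    h = model.conv6_3norm(F.relu(model.conv6_3(h)))
    h = F.relu(model.conv7_1(h))
    h = F.relu(model.conv7_2(h))
    h = model.conv7_3norm(F.relu(model.conv7_3(h)))
    # Block 8: learned upsampling, then the 313-way classification head.
    h = F.relu(model.conv8_1(h))
    h = F.relu(model.conv8_2(h))
    h = F.relu(model.conv8_3(h))
    return model.conv313(h)  # per-pixel scores over the 313 ab bins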