Example #1
    def __init__(self, model):
        self.model = model
        # History objects to accumulate loss and figure metrics
        self.hist_loss = hl.History()
        self.hist_fig = hl.History()
        # Separate canvases for the heatmap (hm) and part-affinity-field (paf) plots
        self.canvas_hm = hl.Canvas()
        self.canvas_paf = hl.Canvas()
        # Current (epoch, batch) step
        self._step = (0, 0)
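    # A sketch (not part of the original example) of how such a History/Canvas
    # pair is typically driven from a training step; `loss_hm` and `loss_paf`
    # are hypothetical inputs, while log()/draw_plot() are hiddenlayer API.
    def log_step(self, epoch, batch, loss_hm, loss_paf):
        self._step = (epoch, batch)
        # Store both losses under the shared (epoch, batch) step key
        self.hist_loss.log(self._step, loss_hm=loss_hm, loss_paf=loss_paf)
        # Redraw one metric per canvas
        with self.canvas_hm:
            self.canvas_hm.draw_plot(self.hist_loss["loss_hm"])
        with self.canvas_paf:
            self.canvas_paf.draw_plot(self.hist_loss["loss_paf"])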
Example #2
    def __init__(self, opt):
        self.opt = opt
        self.cuda = opt.cuda
        self.is_train = opt.is_train
        self.device = torch.device(
            'cuda:{}'.format(self.cuda[0]) if self.cuda else 'cpu')
        self.save_dir = osp.join(opt.ckpt_root, opt.name)
        self.optimizer = None
        self.loss = None

        # init mesh data
        self.nclasses = opt.nclasses

        # init network
        self.net = networks.get_net(opt)
        self.net.train(self.is_train)

        # criterion
        self.loss = networks.get_loss(self.opt).to(self.device)

        if self.is_train:
            # self.optimizer = adabound.AdaBound(
            #     params=self.net.parameters(), lr=self.opt.lr, final_lr=self.opt.final_lr)
            self.optimizer = optim.SGD(self.net.parameters(),
                                       lr=opt.lr,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
            self.scheduler = networks.get_scheduler(self.optimizer, self.opt)
        if not self.is_train or opt.continue_train:
            self.load_state(opt.last_epoch)

        # A History object to store metrics
        self.history = hl.History()
        # A Canvas object to draw the metrics
        self.canvas = hl.Canvas()
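    # A sketch (not part of the original class) of how the History/Canvas pair
    # above might be fed once per logging interval; `epoch` and `loss` are
    # hypothetical arguments.
    def log_metrics(self, epoch, loss):
        # Record the running loss and the current learning rate
        self.history.log(epoch, loss=loss,
                         lr=self.optimizer.param_groups[0]['lr'])
        with self.canvas:
            self.canvas.draw_plot(self.history["loss"])
            self.canvas.draw_plot(self.history["lr"])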
Example #3
    def test_train(self):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_DIR,
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=64,
            shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_DIR,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=1000,
            shuffle=True)

        model = Net().to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

        # Create History object
        model.history = hl.History()
        model.canvas = hl.Canvas()

        for epoch in range(1, 3):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
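
# The train()/test() helpers are not shown in this listing. A minimal sketch
# of a train() that logs into model.history (assumes Net ends in log_softmax
# and `import torch.nn.functional as F`):
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            # Log under an (epoch, batch) step and redraw the curve
            model.history.log((epoch, batch_idx), loss=loss.item())
            with model.canvas:
                model.canvas.draw_plot(model.history["loss"])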
Example #4
    def __init__(self, config, train_loader, valid_loader, test_loader):

        # Data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.bce_loss = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lamda = config.lamda

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size
        self.save_model = config.save_model

        # Plots
        self.loss_history = hl.History()
        self.acc_history = hl.History()
        self.dc_history = hl.History()
        self.canvas = hl.Canvas()

        # Step size for plotting
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Paths
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        # Model training properties
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()
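    # build_model() is not shown in this listing; a minimal sketch consistent
    # with the fields above (the U_Net class and the Adam betas are
    # assumptions based on the config names):
    def build_model(self):
        self.unet = U_Net(img_ch=self.img_ch, output_ch=self.output_ch)
        self.optimizer = torch.optim.Adam(self.unet.parameters(),
                                          lr=self.lr,
                                          betas=(self.beta1, self.beta2))
        self.unet.to(self.device)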
Example #5
    def __init__(self, opt):
        self.opt = opt
        self.name = opt.name
        self.save_path = os.path.join(opt.ckpt_root, opt.name)
        self.train_loss = os.path.join(self.save_path, 'train_loss.txt')
        self.test_loss = os.path.join(self.save_path, 'test_loss.txt')

        # set display
        if opt.is_train and SummaryWriter is not None:
            self.display = SummaryWriter()  # comment=opt.name
        else:
            self.display = None

        self.start_logs()
        self.nexamples = 0
        self.ncorrect = 0

        # A History object to store metrics
        self.history = hl.History()

        # A Canvas object to draw the metrics
        self.canvas = hl.Canvas()
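    # start_logs() is not shown; a hypothetical sketch that only ensures the
    # save directory and the two loss files named above exist:
    def start_logs(self):
        os.makedirs(self.save_path, exist_ok=True)
        for path in (self.train_loss, self.test_loss):
            open(path, 'a').close()  # create the file if missing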
Example #6
    def __init__(self, train_iteration, val_iteration, _plot = False, _enable_full_plot = True):
        self._plot = _plot
        self._enable_full_plot = _enable_full_plot
        self.history = hl.History()
        self.text_table = Texttable(max_width = 0)  #unlimited
        self.text_table.set_precision(4)
        self.train_timer = Timer(train_iteration)
        if val_iteration:
            self.val_timer = Timer(val_iteration)
        self.redis = redis.Redis()
        self.redis.set('progress', 0)
        self.redis.set('desc', '')
        self.redis.set('stage', 'stop')
        self.redis.set(
            'history',
            pickle.dumps({
                'train_loss': [],
                'train_acc': [],
                'val_loss': [],
                'val_acc': []
            })
        )

        # Label + Progress
        self.train_progress = Output()
        self.train_progress_label = Output()
        self.val_progress = Output()
        self.val_progress_label = Output()
        display(self.train_progress_label)
        display(self.train_progress)
        display(self.val_progress_label)
        display(self.val_progress)
        # Train
        with self.train_progress:
            self.train_progress_bar = IntProgress(bar_style = 'info')
            self.train_progress_bar.min = 0
            self.train_progress_bar.max = train_iteration
            display(self.train_progress_bar)
        with self.train_progress_label:
            self.train_progress_label_text = Label(value = "Initialization")
            display(self.train_progress_label_text)
        # Validate
        with self.val_progress:
            self.val_progress_bar = IntProgress(bar_style = 'warning')
            self.val_progress_bar.min = 0
            self.val_progress_bar.max = 1 if val_iteration is None else val_iteration
            display(self.val_progress_bar)
        with self.val_progress_label:
            self.val_progress_label_text = Label(value = "Initialization")
            display(self.val_progress_label_text)

        # Plots
        if self._plot:
            # 4 chartplots
            self.loss_plot = Output()
            self.matrix_plot = Output()
            display(HBox([self.loss_plot, self.matrix_plot]))
            self.lr_plot = Output()
            self.norm_plot = Output()
            display(HBox([self.lr_plot, self.norm_plot]))

            # Canvas
            self.loss_canvas = hl.Canvas()
            self.matrix_canvas = hl.Canvas()
            self.lr_canvas = hl.Canvas()
            self.norm_canvas = hl.Canvas()

        # Memory
        if self._enable_full_plot:
            gpu_count = nvmlDeviceGetCount()
            total_bars = [Output() for _ in range(2 * gpu_count)]
            self.gpu_mem_monitor = total_bars[::2]
            self.gpu_utils_monitor = total_bars[1::2]
            display(HBox(total_bars))
            self.gpu_mem_monitor_bar = []
            self.gpu_utils_monitor_bar = []
            for i, (membar, utilsbar) in enumerate(zip(self.gpu_mem_monitor, self.gpu_utils_monitor)):
                with membar:
                    self.gpu_mem_monitor_bar.append(
                        IntProgress(orientation = 'vertical', bar_style = 'success')
                    )
                    self.gpu_mem_monitor_bar[-1].description = 'M' + str(i) + ': 0%'
                    self.gpu_mem_monitor_bar[-1].min = 0
                    self.gpu_mem_monitor_bar[-1].max = 100
                    display(self.gpu_mem_monitor_bar[-1])

                with utilsbar:
                    self.gpu_utils_monitor_bar.append(
                        IntProgress(orientation = 'vertical', bar_style = 'success')
                    )
                    self.gpu_utils_monitor_bar[-1].description = 'U' + str(i) + ': 0%'
                    self.gpu_utils_monitor_bar[-1].min = 0
                    self.gpu_utils_monitor_bar[-1].max = 100
                    display(self.gpu_utils_monitor_bar[-1])


            # Customize
            self.custom_train_output = Output()
            self.custom_val_output = Output()
            display(HBox([self.custom_train_output, self.custom_val_output]))

            # Log
            self.text_log = Output()
            display(self.text_log)


            # Start monitor thread
            global _STOP_GPU_MONITOR_
            _STOP_GPU_MONITOR_ = False
            self.thread = threading.Thread(
                target = _gpu_monitor_worker,
                args = (self.gpu_mem_monitor_bar, self.gpu_utils_monitor_bar)
            )
            self.thread.start()
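
# _gpu_monitor_worker is referenced above but not shown; a sketch of what such
# a worker might look like with pynvml (the 1-second poll interval and
# `import time` are assumptions, not from the original):
def _gpu_monitor_worker(mem_bars, utils_bars):
    # Poll NVML and push readings into the ipywidgets bars until stopped
    while not _STOP_GPU_MONITOR_:
        for i, (mem_bar, utils_bar) in enumerate(zip(mem_bars, utils_bars)):
            handle = nvmlDeviceGetHandleByIndex(i)
            mem = nvmlDeviceGetMemoryInfo(handle)
            util = nvmlDeviceGetUtilizationRates(handle)
            mem_pct = int(100 * mem.used / mem.total)
            mem_bar.value = mem_pct
            mem_bar.description = 'M' + str(i) + ': ' + str(mem_pct) + '%'
            utils_bar.value = util.gpu
            utils_bar.description = 'U' + str(i) + ': ' + str(util.gpu) + '%'
        time.sleep(1)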
Example #7
from lib.db import connect
from lib.data.loader import LoanPerformanceDataset
from lib.enums import PRE_PROCESSING_ENCODERS_PICKLE_PATH, LIVE_PRE_PROCESSING_ENCODERS_PICKLE_PATH

import numpy as np
import torch.nn as nn

import hiddenlayer as hl

# Visualization Loading
#-- One Chart --#
# A History object to store metrics
history1 = hl.History()

# A Canvas object to draw the metrics
canvas1 = hl.Canvas()

LOCAL = False
USE_LIVE_PRE_PROCESSORS = not LOCAL
CHUNK_SIZE = 100  # fails on chunks 250/500/750 (GPU limit over 750)
LOADER_ARGS = dict(
    batch_size=2,  # batch size relative to the query: 1 == full query size, 2 == half
    num_workers=1,
    shuffle=True)

dataset = LoanPerformanceDataset(
    chunk=CHUNK_SIZE,  # size of the query (use a large number here)
    conn=connect(local=LOCAL).connect(),  # local or remote database (docker, google cloud)
    ignore_headers=['loan_id'],
)  # call truncated in the source listing; remaining arguments omitted
Example #8
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network definition
net = vo.OSVOS(pretrained=0)
net.load_state_dict(torch.load(os.path.join(save_dir, parentModelName+'_epoch-'+str(parentEpoch-1)+'.pth'),
                               map_location=lambda storage, loc: storage))

# Logging into Tensorboard
ROOT_LOG_DIR = "/content/drive/My Drive/MP"
TENSORBOARD_DIR = "qXX" # Sub-Directory for storing this specific experiment's logs
writer = SummaryWriter(os.path.join(ROOT_LOG_DIR, TENSORBOARD_DIR))

net.to(device)  # PyTorch 0.4.0 style

# Visualize the network
c = hl.Canvas()
if vis_net:
    x = torch.randn(1, 3, 480, 854)
    x.requires_grad_()
    x = x.to(device)

    # hl.build_graph(net, x)

    # g = hl.build_graph(net, x)
    # g.save( "pytorch_osvos.pdf")

    y = net.forward(x)
    g = viz.make_dot(y, net.state_dict())
    g.view()

    # debug()
Example #9
# save the graphic

# im.save(path="path_of_file/name_of_file", format="jpg")  # correct pathing

# -- scipy issue
# ImportError: No module named 'scipy._lib.decorator'

# -- uninstalled and reinstalled scipy and issue remains

#-- One Chart --#
# A History object to store metrics
history1 = hl.History()

# A Canvas object to draw the metrics
canvas1 = hl.Canvas()

# Simulate a training loop with two metrics: loss and accuracy
loss = 1
accuracy = 0
for step in range(800):
    # Fake loss and accuracy
    loss -= loss * np.random.uniform(-.09, 0.1)
    accuracy = max(0, accuracy + (1 - accuracy) * np.random.uniform(-.09, 0.1))

    # Log metrics and display them at certain intervals
    if step % 10 == 0:
        # Store metrics in the history object
        history1.log(step, loss=loss, accuracy=accuracy)

        # Plot the two metrics in one graph
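        # (The listing breaks off here; the standard hiddenlayer call to
        # finish this demo would presumably be:)
        with canvas1:
            canvas1.draw_plot([history1["loss"], history1["accuracy"]])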
Example #10
    def train(self,
              features,
              classifier,
              device,
              epochs,
              data_original,
              batch_size=32,
              round=-1,
              por=0.1):
        # def train(self, classifier, epochs, data_original, batch_size=32, round=-1 , por=0.1 ):
        c = hl.Canvas()
        hidden = hl.build_graph(self.D, torch.zeros([1, 542]).cuda())
        hidden.save('./discriminator.png')

        # hidden = hl.build_graph(self.G, torch.zeros([1, 642]).cuda())
        # hidden.save('./generator.png')

        (xmal, ymal), (xben, yben), (xtsmal, xtsben), (ytsmal, ytsben) = data_original[0], data_original[1], \
                                                                         data_original[2], data_original[3]

        xtrain = np.concatenate([xben, xmal])
        ytrain = np.concatenate([yben, ymal])

        # Sampling for unbalanced data
        train_loader = sampling(xtrain, ytrain, batch_size)

        # classifier = train_target_model(classifier, train_loader, 100)

        print('\nTRAINING GAN...\n')
        print('=============data============= ')
        print('number of original malware = {0}\n'
              'number of original ben = {1}\n'
              'number of original test = {2}\n'.format(xmal.shape[0],
                                                       xben.shape[0],
                                                       xtsmal.shape[0]))
        print('==============================\n ')

        start_train = timer()
        Train_FNR, Test_FNR = [], []
        best_test_FNR, best_train_FNR = 0.0, 0.0

        self.gloss, self.dloss = [], []
        list_of_added_features = []
        g_grad_changes = []
        lb = preprocessing.LabelBinarizer()
        lb.fit([0, 1])

        for epoch in range(epochs):
            list_of_distortion = []
            batch_added_features = []
            for local_batch, local_lable in train_loader:
                # for step in range(xtrain.shape[0] // batch_size):
                xmal_batch = check_lenght(
                    local_batch[(local_lable != 0).nonzero()])
                xben_batch = check_lenght(
                    local_batch[(local_lable == 0).nonzero()])

                # xmal_batch = check_lenght(xtrain[(ytrain != 0).nonzero()])
                # xben_batch = check_lenght(xtrain[(ytrain == 0).nonzero()])
                # xmal_batch = torch.from_numpy(xmal_batch)

                noise = torch.rand(xmal_batch.shape[0], self.noise_size)
                yclassifierben_batch = classifier.predict(xben_batch)
                # yclassifierben_batch = classifier(xben_batch.float().cuda())

                # xben_batch= torch.from_numpy(xben_batch)

                # Generate a batch of new malware examples
                self.D.zero_grad()
                # gen_examples =binarize(self.G([xmal_batch.float().cuda(), noise.float().cuda()]).detach()))
                gen_examples = self.G(
                    [xmal_batch.float().cuda(),
                     noise.float().cuda()]).detach()

                # Check which features were added in this batch and record them
                batch_added_features.append(
                    check_added_features(
                        binarize(gen_examples) - xmal_batch.cuda().float(),
                        features))

                # ---------------------
                #  Train Discriminator
                # ---------------------
                d_fake_decision = self.D(gen_examples)
                gen_examples_classifier_lable = classifier.predict(
                    binarize(gen_examples).cpu().detach().numpy())
                fake_data_lable = torch.cat(
                    (torch.zeros(gen_examples_classifier_lable.shape[0],
                                 1).cuda(),
                     torch.ones(gen_examples_classifier_lable.shape[0],
                                1).cuda()), 1)
                # fake_data_lable = torch.cat((torch.zeros(d_fake_decision.shape[0] ,1).cuda() , torch.ones(d_fake_decision.shape[0],1).cuda()) , 1)

                d_loss_fake = self.criterion(d_fake_decision, fake_data_lable)
                d_real_decision = self.D(xben_batch.float().cuda())
                # the discriminator uses the classifier's predicted labels

                real_lable = np.hstack((1 - lb.transform(yclassifierben_batch),
                                        lb.transform(yclassifierben_batch)))
                # the discriminator just uses the true labels
                # real_lable = torch.cat((torch.ones(d_real_decision.shape[0], 1).cuda(), torch.zeros(d_real_decision.shape[0], 1).cuda()),
                #       1)
                # d_loss_real= self.criterion(d_real_decision,real_lable)
                d_loss_real = self.criterion(
                    d_real_decision,
                    torch.from_numpy(real_lable).float().cuda())
                # d_loss_real = self.criterion(d_real_decision,yclassifierben_batch.detach())
                # d_loss_real = self.criterion(d_real_decision,torch.zeros_like(d_real_decision))
                d_loss = 0.5 * torch.add(d_loss_real, d_loss_fake)
                d_loss_real.backward(retain_graph=True)
                d_loss_fake.backward(retain_graph=True)
                # plot_grad_flow(self.D.named_parameters())

                self.d_optimizer.step()

                # ---------------------
                #  Train Generator
                # ---------------------

                noise = torch.rand(xmal_batch.shape[0], self.noise_size)
                self.G.zero_grad()
                g_fake_data = self.G(
                    [xmal_batch.float().cuda(),
                     noise.float().cuda()])
                dg_fake_decision = self.D(g_fake_data)
                g_desired_lable = torch.cat(
                    (torch.ones(d_fake_decision.shape[0], 1).cuda(),
                     torch.zeros(d_fake_decision.shape[0], 1).cuda()), 1)
                # g_loss_samples = nn.functional.mse_loss(dg_fake_decision, g_desired_lable)
                g_loss_samples = self.criterion(dg_fake_decision,
                                                g_desired_lable)
                # g_loss_samples.reqiers_grad = True
                if self.losstype == 'limited_distortion':
                    orig_adv_dist = torch.sum(
                        g_fake_data - xmal_batch.float().cuda(), 1)
                    g_loss = g_loss_samples + 0.001 * torch.mean(
                        torch.norm(g_fake_data - xmal_batch.float().cuda(), 1,
                                   1))
                    list_of_distortion.append(
                        orig_adv_dist.cpu().detach().numpy())
                elif self.losstype == 'unlimited_distortion':
                    # no distortion penalty; this branch also keeps g_loss
                    # defined for the backward() call below
                    g_loss = g_loss_samples

                # g_loss_samples.backward(retain_graph=True)
                g_loss.backward(retain_graph=True)
                plot_grad_flow(self.G.named_parameters())
                g_grad_changes.append(grad_changes(self.G.named_parameters()))
                self.g_optimizer.step()

                torch.cuda.empty_cache()

            list_of_added_features.append(
                pd.DataFrame(batch_added_features).mean(axis=0).to_dict())
            # store scalars; CUDA tensors can't be plotted by matplotlib
            self.gloss.append(g_loss.item())
            self.dloss.append(d_loss.item())

            # Compute attack success rate on train and test data
            train_FNR = self.FNR(xmal, classifier)
            test_FNR = self.FNR(xtsmal, classifier)

            if train_FNR > best_train_FNR:
                best_train_FNR = train_FNR
            Train_FNR.append(train_FNR)
            Test_FNR.append(test_FNR)
            if test_FNR > best_test_FNR:
                best_test_FNR = test_FNR
                print('saving malgan weights at epoch: %d' % epoch)

                # print("[D loss: %f]  [G loss: %f] \n" % ( self.dloss[-1],self.gloss[-1]))
                torch.save(
                    self.G,
                    '/home/maryam/Code/python/adversarial_training/torch_impl/_'
                    + str(por) + '/malgan{0}.pt'.format(round))

            print(
                "%d [D loss: %f] [G loss: %f] [train_FNR: %f] [test_FNR: %f] [distortion: %f]"
                % (epoch, d_loss, g_loss, train_FNR, test_FNR,
                   pd.DataFrame(batch_added_features).mean(axis=0).sum()))
        end_train = timer()
        del g_loss, d_loss_real, d_loss_fake  # d_loss

        print('\ntraining completed in %.3f seconds.\n' %
              (end_train - start_train))
        print('=============results ============= ')

        print(' attack success rate using train data: {0} \n'
              ' attack success rate using test data: {1}'.format(
                  best_train_FNR, best_test_FNR))
        print('==============================\n ')

        plot_added_featues(list_of_added_features)
        plot_added_featues(g_grad_changes)

        # Plot losses
        plt.figure()
        plt.plot(range(len(self.gloss)),
                 self.gloss,
                 c='r',
                 label='g_loss_rec',
                 linewidth=2)
        plt.plot(range(len(self.dloss)),
                 self.dloss,
                 c='g',
                 linestyle='--',
                 label='d_loss',
                 linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(
            '/home/maryam/Code/python/adversarial_training/torch_impl/_' +
            str(por) + '/GAN_Epoch_loss({0}).png'.format(round))
        plt.show()
        plt.close()

        # Plot TPR
        plt.figure()
        plt.plot(range(len(Train_FNR)),
                 Train_FNR,
                 c='r',
                 label='Training Set',
                 linewidth=2)
        # plt.plot(range(len(Train_FNR)), Train_FNR, c='r', label='Attack success rate', linewidth=2)
        plt.plot(range(len(Test_FNR)),
                 Test_FNR,
                 c='g',
                 linestyle='--',
                 label='Test Set',
                 linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel('FNR')
        plt.legend()
        plt.savefig(
            '/home/maryam/Code/python/adversarial_training/torch_impl/_' +
            str(por) + '/Epoch_FNR({0}).png'.format(round))
        plt.show()
        plt.close()

        return [best_train_FNR, best_test_FNR, (end_train - start_train)]
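
    # self.FNR is called above but not shown; a hypothetical sketch of an
    # attack-success-rate measure (fraction of generated malware that the
    # classifier labels benign). The generator call mirrors the usage above;
    # everything else is an assumption.
    def FNR(self, xmal, classifier):
        noise = torch.rand(xmal.shape[0], self.noise_size)
        gen_examples = self.G([torch.from_numpy(xmal).float().cuda(),
                               noise.float().cuda()]).detach()
        preds = classifier.predict(binarize(gen_examples).cpu().numpy())
        return np.mean(preds == 0)  # malware classified as benign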
Example #11
    def train(self, args):
        # Image transforms
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_dirs = utils.get_traindata_link(args.dataset_dir)

        # Initialize dataloader
        a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainA'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
        b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(
            dataset_dirs['trainB'], transform=transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

        a_fake_sample = utils.Sample_from_Pool()
        b_fake_sample = utils.Sample_from_Pool()

        # live plot loss
        Gab_history = hl.History()
        Gba_history = hl.History()
        gan_history = hl.History()
        Da_history = hl.History()
        Db_history = hl.History()

        canvas = hl.Canvas()

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):

                # Compute the global step index
                step = epoch * min(len(a_loader), len(b_loader)) + i + 1

                # Generators ===============================================================
                # Turning off grads for discriminators
                set_grad([self.Da, self.Db], False)

                # Zero out grads of the generator
                self.g_optimizer.zero_grad()

                # Real images from sets A and B
                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = utils.cuda([a_real, b_real])

                # Passing through generators
                # Nomenclature: a_fake is the fake image generated from b_real, in domain A.
                # NOTE: Gab generates a from b, and vice versa
                a_fake = self.Gab(b_real)
                b_fake = self.Gba(a_real)

                a_recon = self.Gab(b_fake)
                b_recon = self.Gba(a_fake)

                # Each generator should reproduce an image from its own domain
                # when given an input from that domain (identity mapping)
                a_idt = self.Gab(a_real)
                b_idt = self.Gba(b_real)

                # Identity loss
                a_idt_loss = self.L1(a_idt, a_real) * args.delta
                b_idt_loss = self.L1(b_idt, b_real) * args.delta

                # Adversarial loss
                # Da returns 1 for an image in domain A
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                # Label expected here is 1 to fool the discriminator
                expected_label_a = utils.cuda(
                    Variable(torch.ones(a_fake_dis.size())))
                expected_label_b = utils.cuda(
                    Variable(torch.ones(b_fake_dis.size())))

                a_gen_loss = self.MSE(a_fake_dis, expected_label_a)
                b_gen_loss = self.MSE(b_fake_dis, expected_label_b)

                # Cycle Consistency loss
                a_cycle_loss = self.L1(a_recon, a_real) * args.alpha
                b_cycle_loss = self.L1(b_recon, b_real) * args.alpha

                # Structural Cycle Consistency loss
                a_scyc_loss = self.ssim(a_recon, a_real) * args.beta
                b_scyc_loss = self.ssim(b_recon, b_real) * args.beta

                # Structural similarity loss
                # ba refers to the SSIM scores between the input and the output of Gba
                # the grayscale image values are in the range 0-1
                gray = kornia.color.RgbToGrayscale()
                a_real_gray = gray((a_real + 1) / 2.0)
                a_fake_gray = gray((a_fake + 1) / 2.0)
                a_recon_gray = gray((a_recon + 1) / 2.0)
                b_real_gray = gray((b_real + 1) / 2.0)
                b_fake_gray = gray((b_fake + 1) / 2.0)
                b_recon_gray = gray((b_recon + 1) / 2.0)

                ba_ssim_loss = (
                    (self.ssim(a_real_gray, b_fake_gray)) +
                    (self.ssim(a_fake_gray, b_recon_gray))) * args.gamma
                ab_ssim_loss = (
                    (self.ssim(b_real_gray, a_fake_gray)) +
                    (self.ssim(b_fake_gray, a_recon_gray))) * args.gamma

                # Total Generator Loss
                gen_loss = (a_gen_loss + b_gen_loss + a_cycle_loss +
                            b_cycle_loss + a_scyc_loss + b_scyc_loss +
                            a_idt_loss + b_idt_loss + ba_ssim_loss +
                            ab_ssim_loss)

                # Update Generators
                gen_loss.backward()
                self.g_optimizer.step()

                # Discriminators ===========================================================
                # Turn on grads for discriminators
                set_grad([self.Da, self.Db], True)
                self.d_optimizer.zero_grad()

                # Sample from previously generated fake images
                a_fake = Variable(
                    torch.Tensor(
                        a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = utils.cuda([a_fake, b_fake])

                # Pass through discriminators
                # Discriminator for domain A
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)

                # Discriminator for domain B
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)

                # Expected label for real image is 1
                exp_real_label_a = utils.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                exp_fake_label_a = utils.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                exp_real_label_b = utils.cuda(
                    Variable(torch.ones(b_real_dis.size())))
                exp_fake_label_b = utils.cuda(
                    Variable(torch.zeros(b_fake_dis.size())))

                # Discriminator losses
                a_real_dis_loss = self.MSE(a_real_dis, exp_real_label_a)
                a_fake_dis_loss = self.MSE(a_fake_dis, exp_fake_label_a)
                b_real_dis_loss = self.MSE(b_real_dis, exp_real_label_b)
                b_fake_dis_loss = self.MSE(b_fake_dis, exp_fake_label_b)

                # Total discriminator loss
                a_dis_loss = (a_fake_dis_loss + a_real_dis_loss) / 2
                b_dis_loss = (b_fake_dis_loss + b_real_dis_loss) / 2

                # Update discriminators
                a_dis_loss.backward()
                b_dis_loss.backward()

                self.d_optimizer.step()

                if i % args.log_freq == 0:
                    # Log losses
                    Gab_history.log(step,
                                    gen_loss=a_gen_loss,
                                    cycle_loss=a_cycle_loss,
                                    idt_loss=a_idt_loss,
                                    ssim_loss=ab_ssim_loss,
                                    scyc_loss=a_scyc_loss)

                    Gba_history.log(step,
                                    gen_loss=b_gen_loss,
                                    cycle_loss=b_cycle_loss,
                                    idt_loss=b_idt_loss,
                                    ssim_loss=ba_ssim_loss,
                                    scyc_loss=b_scyc_loss)

                    Da_history.log(step,
                                   loss=a_dis_loss,
                                   fake_loss=a_fake_dis_loss,
                                   real_loss=a_real_dis_loss)

                    Db_history.log(step,
                                   loss=b_dis_loss,
                                   fake_loss=b_fake_dis_loss,
                                   real_loss=b_real_dis_loss)

                    gan_history.log(step,
                                    gen_loss=gen_loss,
                                    dis_loss=(a_dis_loss + b_dis_loss))

                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, i + 1, min(len(a_loader), len(b_loader)),
                           gen_loss, a_dis_loss + b_dis_loss))
                    with canvas:
                        canvas.draw_plot(
                            [Gba_history['gen_loss'], Gba_history['cycle_loss'],
                             Gba_history['idt_loss'], Gba_history['ssim_loss'],
                             Gba_history['scyc_loss']],
                            labels=['Adv loss', 'Cycle loss', 'Identity loss',
                                    'SSIM', 'SCyC loss'])

                        canvas.draw_plot(
                            [Gab_history['gen_loss'], Gab_history['cycle_loss'],
                             Gab_history['idt_loss'], Gab_history['ssim_loss'],
                             Gab_history['scyc_loss']],
                            labels=['Adv loss', 'Cycle loss', 'Identity loss',
                                    'SSIM', 'SCyC loss'])

                        canvas.draw_plot(
                            [Db_history['loss'], Db_history['fake_loss'],
                             Db_history['real_loss']],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [Da_history['loss'], Da_history['fake_loss'],
                             Da_history['real_loss']],
                            labels=['Loss', 'Fake Loss', 'Real Loss'])

                        canvas.draw_plot(
                            [gan_history['gen_loss'], gan_history['dis_loss']],
                            labels=['Generator loss', 'Discriminator loss'])

            # Overwrite checkpoint
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % (args.checkpoint_path))

            # Save loss history
            history_path = args.results_path + '/loss_history/'
            utils.mkdir([history_path])
            Gab_history.save(history_path + "Gab.pkl")
            Gba_history.save(history_path + "Gba.pkl")
            Da_history.save(history_path + "Da.pkl")
            Db_history.save(history_path + "Db.pkl")
            gan_history.save(history_path + "gan.pkl")

            # Update learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

            # Run one test cycle
            if args.testing:
                print('Testing')
                tst.test(args, epoch)
Example #12
File: FCN8s.py Project: LyazS/ATN
        if os.path.isfile(load_path):
            try:
                checkpoint = torch.load(load_path)
                net_one.load_state_dict(checkpoint['state_one'])
                mIU_benchmark = checkpoint['mIU']
                print("Loaded last checkpoint OK")
                print("mIU =", mIU_benchmark)
            except Exception:
                print("Can't load the checkpoint QAQ")
                mIU_benchmark = 0
        else:
            EPOCH_start = 0
            print("Can't find the checkpoint, starting training from epoch 0 ...")

    his = hl.History()
    canv = hl.Canvas()
    optm_one = torch.optim.Adam(net_one.parameters(),
                                lr=base_LR,
                                weight_decay=wd)
    scheduler_lr_one = torch.optim.lr_scheduler.StepLR(optm_one,
                                                       step_size=decay_step,
                                                       gamma=decay_rate)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir="./log")
    # Loss = nn.NLLLoss()
    Loss = nn.CrossEntropyLoss()

    l_seg = AverageValueMeter()
    l_th = AverageValueMeter()
    l_all = AverageValueMeter()
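
    # The listing ends before the training loop; a sketch (assumptions: a
    # `train_loader` yielding (img, target) pairs and a `num_epochs` count)
    # of how the History/Canvas pair and the meters above could be wired up:
    for epoch in range(EPOCH_start, num_epochs):
        for img, target in train_loader:
            optm_one.zero_grad()
            pred = net_one(img.cuda())
            loss_seg = Loss(pred, target.cuda())
            loss_seg.backward()
            optm_one.step()
            l_seg.add(loss_seg.item())
        scheduler_lr_one.step()
        # Mirror the epoch-mean loss into TensorBoard and hiddenlayer
        writer.add_scalar("loss/seg", l_seg.value()[0], epoch)
        his.log(epoch, seg_loss=l_seg.value()[0])
        with canv:
            canv.draw_plot(his["seg_loss"])
        l_seg.reset()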