Example #1
class TBVisualizer:
    def __init__(self, opt):
        self._opt = opt
        self._save_path = os.path.join(opt.checkpoints_dir, opt.name)

        self._log_path = os.path.join(self._save_path, 'loss_log2.txt')
        self._tb_path = os.path.join(self._save_path, 'summary.json')
        self._writer = SummaryWriter(self._save_path)

        with open(self._log_path, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def __del__(self):
        self._writer.close()

    def display_current_results(self, visuals, it, is_train, save_visuals=False):
        for label, image_numpy in visuals.items():
            sum_name = '{}/{}'.format('Train' if is_train else 'Test', label)
            self._writer.add_image(sum_name, image_numpy, it)

            if save_visuals:
                util.save_image(image_numpy,
                                os.path.join(self._opt.checkpoints_dir, self._opt.name,
                                             'event_imgs', sum_name, '%08d.png' % it))

        self._writer.export_scalars_to_json(self._tb_path)

    def plot_scalars(self, scalars, it, is_train):
        for label, scalar in scalars.items():
            sum_name = '{}/{}'.format('Train' if is_train else 'Test', label)
            self._writer.add_scalar(sum_name, scalar, it)

    def print_current_train_errors(self, epoch, i, iters_per_epoch, errors, t, visuals_were_stored):
        log_time = time.strftime("[%d/%m/%Y %H:%M:%S]")
        visuals_info = "v" if visuals_were_stored else ""
        message = '%s (T%s, epoch: %d, it: %d/%d, t/smpl: %.3fs) ' % (log_time, visuals_info, epoch, i, iters_per_epoch, t)
        for k, v in errors.items():
            message += '%s:%.3f ' % (k, v)

        print(message)
        with open(self._log_path, "a") as log_file:
            log_file.write('%s\n' % message)

    def print_current_validate_errors(self, epoch, errors, t):
        log_time = time.strftime("[%d/%m/%Y %H:%M:%S]")
        message = '%s (V, epoch: %d, time_to_val: %ds) ' % (log_time, epoch, t)
        for k, v in errors.items():
            message += '%s:%.3f ' % (k, v)

        print(message)
        with open(self._log_path, "a") as log_file:
            log_file.write('%s\n' % message)

    def save_images(self, visuals):
        for label, image_numpy in visuals.items():
            image_name = '%s.png' % label
            save_path = os.path.join(self._save_path, "samples", image_name)
            util.save_image(image_numpy, save_path)
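
A minimal usage sketch for the class above (assuming the snippet's imports — os, time, tensorboardX's SummaryWriter, and the project's util module — are in scope; the SimpleNamespace is a stand-in for the real options object):

from types import SimpleNamespace
import numpy as np

opt = SimpleNamespace(checkpoints_dir='./checkpoints', name='demo')
viz = TBVisualizer(opt)
viz.plot_scalars({'loss_g': 0.42}, it=100, is_train=True)
# images are logged as CHW float arrays, the default format add_image expects
viz.display_current_results({'fake': np.random.rand(3, 64, 64)}, it=100, is_train=True)
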
Example #2
    def train(self, epoch_to_restore=0):
        g = Generator(self.nb_channels_first_layer, self.dim)

        if epoch_to_restore > 0:
            filename_model = os.path.join(self.dir_models, 'epoch_{}.pth'.format(epoch_to_restore))
            g.load_state_dict(torch.load(filename_model))
        else:
            g.apply(weights_init)

        g.cuda()
        g.train()

        dataset = EmbeddingsImagesDataset(self.dir_z_train, self.dir_x_train)
        dataloader = DataLoader(dataset, self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
        fixed_dataloader = DataLoader(dataset, 16)
        fixed_batch = next(iter(fixed_dataloader))

        criterion = torch.nn.L1Loss()

        optimizer = optim.Adam(g.parameters())
        writer = SummaryWriter(self.dir_logs)

        try:
            epoch = epoch_to_restore
            while True:
                g.train()
                for _ in range(self.nb_epochs_to_save):
                    epoch += 1

                    for idx_batch, current_batch in enumerate(tqdm(dataloader)):
                        g.zero_grad()
                        x = current_batch['x'].float().cuda()
                        z = current_batch['z'].float().cuda()
                        g_z = g(z)

                        loss = criterion(g_z, x)
                        loss.backward()
                        optimizer.step()

                    writer.add_scalar('train_loss', loss.item(), epoch)

                z = fixed_batch['z'].float().cuda()
                g.eval()
                with torch.no_grad():
                    g_z = g(z)
                images = make_grid(g_z[:16], nrow=4, normalize=True)
                writer.add_image('generations', images, epoch)
                filename = os.path.join(self.dir_models, 'epoch_{}.pth'.format(epoch))
                torch.save(g.state_dict(), filename)

        finally:
            print('[*] Closing Writer.')
            writer.close()
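
weights_init is referenced above via g.apply(weights_init) but not shown; a common DCGAN-style initializer matching that call would look like this (an assumption, not necessarily the author's version):

import torch.nn as nn

def weights_init(m):
    # normal init for conv layers, unit-mean normal for batch norms (DCGAN convention)
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
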
Example #3
class Logger:
    def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
        self._log_dir = log_dir
        print('########################')
        print('logging outputs to ', log_dir)
        print('########################')
        self._n_logged_samples = n_logged_samples
        self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)

    def log_scalar(self, scalar, name, step_):
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase),
                                      scalar_dict, step)

    def log_image(self, image, name, step):
        assert (len(image.shape) == 3)  # [C, H, W]
        self._summ_writer.add_image('{}'.format(name), image, step)

    def log_video(self, video_frames, name, step, fps=10):
        assert len(
            video_frames.shape
        ) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
        print(",,,,,,,,,,", name)
        self._summ_writer.add_video('{}'.format(name),
                                    video_frames,
                                    step,
                                    fps=fps)

    def log_paths_as_videos(self,
                            paths,
                            step,
                            max_videos_to_save=2,
                            fps=10,
                            video_title='video'):

        # reshape the rollouts
        videos = [np.transpose(p['image_obs'], [0, 3, 1, 2]) for p in paths]

        # cap how many rollouts to save, then find the longest rollout
        max_videos_to_save = np.min([max_videos_to_save, len(videos)])
        max_length = videos[0].shape[0]
        for i in range(max_videos_to_save):
            if videos[i].shape[0] > max_length:
                max_length = videos[i].shape[0]

        # pad rollouts to all be same length
        for i in range(max_videos_to_save):
            if videos[i].shape[0] < max_length:
                padding = np.tile([videos[i][-1]],
                                  (max_length - videos[i].shape[0], 1, 1, 1))
                videos[i] = np.concatenate([videos[i], padding], 0)

        # log videos to tensorboard event file
        videos = np.stack(videos[:max_videos_to_save], 0)
        print("...... logging ", video_title)
        self.log_video(videos, video_title, step, fps=fps)

    def log_figures(self, figure, name, step, phase):
        """figure: batch of matplotlib figure handles, shape [batch]"""
        assert figure.shape[0] > 0, "Figure logging requires input shape [batch x figures]!"
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_figure(self, figure, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_graph(self, array, name, step, phase):
        """array: data rendered to an image via plot_graph and logged"""
        im = plot_graph(array)
        self._summ_writer.add_image('{}_{}'.format(name, phase), im, step)

    def dump_scalars(self, log_path=None):
        log_path = os.path.join(
            self._log_dir,
            "scalar_data.json") if log_path is None else log_path
        self._summ_writer.export_scalars_to_json(log_path)

    def flush(self):
        self._summ_writer.flush()
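
A short usage sketch (hypothetical shapes; add_video additionally requires moviepy to encode the frames):

import numpy as np

logger = Logger('./logs')
logger.log_scalar(0.5, 'train/loss', step_=1)
# one fake rollout: 8 frames of 16x16 RGB stored HWC, as log_paths_as_videos expects
paths = [{'image_obs': np.zeros((8, 16, 16, 3), dtype=np.uint8)}]
logger.log_paths_as_videos(paths, step=1, max_videos_to_save=1, video_title='rollout')
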
Example #4
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes)

    db_train = BaseDataSets(base_dir=args.root_path, split="train", num=None, transform=transforms.Compose([
        RandomGenerator(args.patch_size)
    ]))

    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_path, args.labeled_num)
    print("Total silices is: {}, labeled slices is: {}".format(
        total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(
        labeled_idxs, unlabeled_idxs, batch_size, batch_size-args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler,
                             num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn)

    db_val = BaseDataSets(base_dir=args.root_path, split="val")
    valloader = DataLoader(db_val, batch_size=1, shuffle=False,
                           num_workers=1)

    model.train()

    optimizer = optim.SGD(model.parameters(), lr=base_lr,
                          momentum=0.9, weight_decay=0.0001)

    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs].long())
            loss_dice = dice_loss(
                outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num//150)
            consistency_loss = losses.entropy_loss(outputs_soft, C=4)
            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss',
                              consistency_loss, iter_num)
            writer.add_scalar('info/consistency_weight',
                              consistency_weight, iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(
                    outputs, dim=1), dim=1, keepdim=True)
                writer.add_image('train/Prediction',
                                 outputs[1, ...] * 50, iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(
                        sampled_batch["image"], sampled_batch["label"], model, classes=num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                for class_i in range(num_classes-1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i+1),
                                      metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i+1),
                                      metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]

                mean_hd95 = np.mean(metric_list, axis=0)[1]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path,
                                                  'iter_{}_dice_{}.pth'.format(
                                                      iter_num, round(best_performance, 4)))
                    save_best = os.path.join(snapshot_path,
                                             '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info(
                    'iteration %d : mean_dice : %f mean_hd95 : %f' % (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(
                    snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
Example #5
def main(args):
    tb_dir = '/logs/tb_logs_article/fine_tuning_' + args.name
    tb = SummaryWriter(tb_dir)

    train_loader, val_loader = get_dataloaders(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.model not in ["UNET"]:
        raise Exception('Unsupported model type, choose from ["UNET"]')

    gen = MODEL_LOSS[args.model]['gen']
    unet = MODEL_LOSS[args.model]['unet']
    loss_func = MODEL_LOSS[args.model]['loss']

    gen = gen.to(device)
    unet = unet.to(device)

    if 'gen_path' in args and args.gen_path is not None:
        gen = load_model(gen, args.gen_path)
    if 'unet_path' in args and args.unet_path is not None:
        unet = load_model(unet, args.unet_path)

    tb.add_text(tag='gen', text_string=repr(gen))

    gen_opt = torch.optim.Adamax(gen.parameters(), lr=0.00005)
    unet_opt = torch.optim.Adamax(unet.parameters(), lr=0.00002)

    global_step = 0

    for epoch in range(args.n_epochs):
        gen.train()
        gen_step = 0
        disc_step = 0
        for x_input, y_extract, y_restor in tqdm(train_loader):
            # data reading
            #             unet.train()
            x_input = torch.FloatTensor(x_input).to(device)
            y_extract = y_extract.type(
                torch.FloatTensor).to(device).unsqueeze(1)
            y_restor = 1. - y_restor.type(
                torch.FloatTensor).to(device).unsqueeze(1)

            unet.eval()
            with torch.no_grad():
                logits_extract = unet(x_input)
                logits_extract = 1 - logits_extract.unsqueeze(1)

            logits_restore = gen(logits_extract).unsqueeze(1)  # restoration + extraction
            # if Cleaning loss use this
            gen_loss = loss_func(1 - (logits_extract + logits_restore), None,
                                 1 - y_restor, None)
            # else if with_restore =True use this
            # input_fake = torch.cat((logits_extract + logits_restore,logits_extract),dim = 1)
            #
            # gen_loss =  loss_func(logits_extract, input_fake, y_extract, y_restor)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            gen_step += 1

            if np.random.random() <= 0.5:
                disc_step += 1

            global_step += 1

            if global_step <= 1:
                continue

            tb.add_scalar('gen_vectran_loss',
                          gen_loss.item(),
                          global_step=global_step)
            #             tb.add_scalar('train_loss', loss.cpu().data.numpy(), global_step=global_step)

            if global_step % 100 == 0 or global_step <= 2:
                out_grid = torchvision.utils.make_grid(
                    1. -
                    torch.clamp(logits_extract + logits_restore, 0, 1).cpu())
                input_grid = torchvision.utils.make_grid(1. -
                                                         logits_extract.cpu())
                true_grid = torchvision.utils.make_grid(1. - y_restor.cpu())
                input_clean_grid = torchvision.utils.make_grid(x_input.cpu())

                tb.add_image(tag='train_first_input',
                             img_tensor=input_clean_grid,
                             global_step=global_step)
                tb.add_image(tag='train_out_extract',
                             img_tensor=out_grid,
                             global_step=global_step)
                tb.add_image(tag='train_input',
                             img_tensor=input_grid,
                             global_step=global_step)
                tb.add_image(tag='train_true',
                             img_tensor=true_grid,
                             global_step=global_step)
                gen.eval()

                with torch.no_grad():
                    unet.eval()
                    validate(tb,
                             val_loader,
                             unet,
                             gen,
                             loss_func,
                             global_step=global_step)

                gen.train()
                unet.train()

                save_model(gen,
                           os.path.join(tb_dir, 'gen_it_%s.pth' % global_step))
                save_model(
                    unet, os.path.join(tb_dir, 'unet_it_%s.pth' % global_step))
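
save_model and load_model are project helpers not shown here; minimal stand-ins consistent with how they are called above might be:

import torch

def save_model(model, path):
    # persist only the weights, matching the gen_it_*.pth / unet_it_*.pth files above
    torch.save(model.state_dict(), path)

def load_model(model, path):
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model
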
Example #6
class Trainer(object):
    def __init__(self, graph, optim, lrschedule, loaded_step,
                 devices, data_device,
                 dataset, hparams):
        if isinstance(hparams, str):
            hparams = JsonConfig(hparams)
        # set members
        # append date info
        date = str(datetime.datetime.now())
        date = date[:date.rfind(":")].replace("-", "")\
                                     .replace(":", "")\
                                     .replace(" ", "_")
        self.log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
        self.checkpoints_dir = os.path.join(self.log_dir, "checkpoints")
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        # write hparams
        hparams.dump(self.log_dir)
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        self.checkpoints_gap = hparams.Train.checkpoints_gap
        self.max_checkpoints = hparams.Train.max_checkpoints
        # model relative
        self.graph = graph
        self.optim = optim
        self.weight_y = hparams.Train.weight_y
        # grad operation
        self.max_grad_clip = hparams.Train.max_grad_clip
        self.max_grad_norm = hparams.Train.max_grad_norm
        # copy devices from built graph
        self.devices = devices
        self.data_device = data_device
        # number of training batches
        self.batch_size = hparams.Train.batch_size
        self.data_loader = DataLoader(dataset,
                                      batch_size=self.batch_size,
                                    #   num_workers=8,
                                      shuffle=True,
                                      drop_last=True)
        self.n_epoches = (hparams.Train.num_batches+len(self.data_loader)-1)
        self.n_epoches = self.n_epoches // len(self.data_loader)
        self.global_step = 0
        # lr schedule
        self.lrschedule = lrschedule
        self.loaded_step = loaded_step
        # data relative
        self.y_classes = hparams.Glow.y_classes
        self.y_condition = hparams.Glow.y_condition
        self.y_criterion = hparams.Criterion.y_condition
        assert self.y_criterion in ["multi-classes", "single-class"]

        # log relative
        # tensorboard
        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.scalar_log_gaps = hparams.Train.scalar_log_gap
        self.plot_gaps = hparams.Train.plot_gap
        self.inference_gap = hparams.Train.inference_gap

    def train(self):
        # set to training state
        self.graph.train()
        self.global_step = self.loaded_step
        # begin to train
        for epoch in range(self.n_epoches):
            print("epoch", epoch)
            progress = tqdm(self.data_loader)
            for i_batch, batch in enumerate(progress):
                # update learning rate
                lr = self.lrschedule["func"](global_step=self.global_step,
                                             **self.lrschedule["args"])
                for param_group in self.optim.param_groups:
                    param_group['lr'] = lr
                self.optim.zero_grad()
                if self.global_step % self.scalar_log_gaps == 0:
                    self.writer.add_scalar("lr/lr", lr, self.global_step)
                # get batch data
                for k in batch:
                    batch[k] = batch[k].to(self.data_device)
                x = batch["x"]
                y = None
                y_onehot = None
                if self.y_condition:
                    if self.y_criterion == "multi-classes":
                        assert "y_onehot" in batch, "multi-classes ask for `y_onehot` (torch.FloatTensor onehot)"
                        y_onehot = batch["y_onehot"]
                    elif self.y_criterion == "single-class":
                        assert "y" in batch, "single-class ask for `y` (torch.LongTensor indexes)"
                        y = batch["y"]
                        y_onehot = thops.onehot(y, num_classes=self.y_classes)

                # at first time, initialize ActNorm
                if self.global_step == 0:
                    self.graph(x[:self.batch_size // len(self.devices), ...],
                               y_onehot[:self.batch_size // len(self.devices), ...] if y_onehot is not None else None)
                # parallel
                if len(self.devices) > 1 and not hasattr(self.graph, "module"):
                    print("[Parallel] move to {}".format(self.devices))
                    self.graph = torch.nn.parallel.DataParallel(self.graph, self.devices, self.devices[0])
                # forward phase
                z, nll, y_logits = self.graph(x=x, y_onehot=y_onehot)

                # loss
                loss_generative = Glow.loss_generative(nll)
                loss_classes = 0
                if self.y_condition:
                    loss_classes = (Glow.loss_multi_classes(y_logits, y_onehot)
                                    if self.y_criterion == "multi-classes" else
                                    Glow.loss_class(y_logits, y))
                if self.global_step % self.scalar_log_gaps == 0:
                    self.writer.add_scalar("loss/loss_generative", loss_generative, self.global_step)
                    if self.y_condition:
                        self.writer.add_scalar("loss/loss_classes", loss_classes, self.global_step)
                loss = loss_generative + loss_classes * self.weight_y

                # backward
                self.graph.zero_grad()
                self.optim.zero_grad()
                loss.backward()
                # operate grad
                if self.max_grad_clip is not None and self.max_grad_clip > 0:
                    torch.nn.utils.clip_grad_value_(self.graph.parameters(), self.max_grad_clip)
                if self.max_grad_norm is not None and self.max_grad_norm > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), self.max_grad_norm)
                    if self.global_step % self.scalar_log_gaps == 0:
                        self.writer.add_scalar("grad_norm/grad_norm", grad_norm, self.global_step)
                # step
                self.optim.step()

                # checkpoints
                if self.global_step % self.checkpoints_gap == 0 and self.global_step > 0:
                    save(global_step=self.global_step,
                         graph=self.graph,
                         optim=self.optim,
                         pkg_dir=self.checkpoints_dir,
                         is_best=True,
                         max_checkpoints=self.max_checkpoints)
                if self.global_step % self.plot_gaps == 0:
                    img = self.graph(z=z, y_onehot=y_onehot, reverse=True)
                    # img = torch.clamp(img, min=0, max=1.0)
                    if self.y_condition:
                        if self.y_criterion == "multi-classes":
                            y_pred = torch.sigmoid(y_logits)
                        elif self.y_criterion == "single-class":
                            y_pred = thops.onehot(torch.argmax(F.softmax(y_logits, dim=1), dim=1, keepdim=True),
                                                  self.y_classes)
                        y_true = y_onehot
                    for bi in range(min([len(img), 4])):
                        self.writer.add_image("0_reverse/{}".format(bi), torch.cat((img[bi], batch["x"][bi]), dim=1), self.global_step)
                        if self.y_condition:
                            self.writer.add_image("1_prob/{}".format(bi), plot_prob([y_pred[bi], y_true[bi]], ["pred", "true"]), self.global_step)

                # inference
                if hasattr(self, "inference_gap"):
                    if self.global_step % self.inference_gap == 0:
                        img = self.graph(z=None, y_onehot=y_onehot, eps_std=0.5, reverse=True)
                        # img = torch.clamp(img, min=0, max=1.0)
                        for bi in range(min([len(img), 4])):
                            self.writer.add_image("2_sample/{}".format(bi), img[bi], self.global_step)

                # global step
                self.global_step += 1

        self.writer.export_scalars_to_json(os.path.join(self.log_dir, "all_scalars.json"))
        self.writer.close()
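
The trainer calls self.lrschedule["func"](global_step=..., **self.lrschedule["args"]), so any schedule is just a callable plus a kwargs dict; a hypothetical warm-up-then-decay schedule in that shape:

def noam_lr(global_step, base_lr=1e-3, warmup_steps=4000):
    # linear warm-up to base_lr at warmup_steps, then inverse-sqrt decay
    step = max(global_step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

lrschedule = {"func": noam_lr, "args": {"base_lr": 1e-3, "warmup_steps": 4000}}
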
Example #7
class ExperimentRunner(object):
    """
    Main Class for running experiments
    This class creates the UNET
    This class also creates the datasets
    """
    def __init__(self,
                 train_dataset_path,
                 test_dataset_path,
                 train_batch_size,
                 test_batch_size,
                 model_save_dir,
                 num_epochs=100,
                 num_data_loader_workers=10,
                 num_classes=10,
                 image_size=(512, 1024)):
        # GAN Network + VGG Loss Network
        self.model = UNetSeg()
        # freeze all five encoder conv blocks
        for block in (self.model.encoder.convblock1, self.model.encoder.convblock2,
                      self.model.encoder.convblock3, self.model.encoder.convblock4,
                      self.model.encoder.convblock5):
            for param in block.parameters():
                param.requires_grad = False
        # Network hyperparameters
        self.lr = 1.e-4
        self.num_classes = num_classes
        self.optimizer = torch.optim.Adam(
            [{
                'params': self.model.parameters(),
                'lr': self.lr
            }
             #{'params': self.gan.discriminator.parameters(), 'lr': self.disc_lr}
             ],
            betas=(0.5, 0.999))
        # Network losses
        self.criterion = nn.CrossEntropyLoss(reduction='none').cuda()

        # Train settings + log settings
        self.num_epochs = num_epochs
        self.log_freq = 10  # Steps
        self.test_freq = 1000  # Steps
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size

        self.image_size = image_size

        # Create datasets
        self.train_dataset = CityscapesSegmentation(training=True,
                                                    image_size=self.image_size)
        self.test_dataset = CityscapesSegmentation(split="val",
                                                   training=False,
                                                   image_size=self.image_size)

        self.train_dataset_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=num_data_loader_workers)
        self.test_dataset_loader = DataLoader(self.test_dataset,
                                              batch_size=self.test_batch_size,
                                              shuffle=True,
                                              num_workers=self.test_batch_size)

        # Use the GPU if it's available.
        self.cuda = torch.cuda.is_available()

        if self.cuda:
            self.model.cuda()

        # Tensorboard logger
        self.txwriter = SummaryWriter()
        self.model_save_dir = model_save_dir
        self.save_freq = 5000
        self.display_freq = 500
        self.test_display_freq = 300
        self.edge_lambda = 1.

        #self.model = nn.DataParallel(self.model)

    def _optimize(self, y_pred, y_gt, y_edges):
        """
        VGGLoss + GAN loss
        """
        self.optimizer.zero_grad()
        loss = self.criterion(y_pred, y_gt)
        #pdb.set_trace()
        loss += loss * self.edge_lambda * y_edges
        loss = loss.mean()
        loss.backward()
        self.optimizer.step()
        return loss

    def _adjust_learning_rate(self, epoch, divide_lr=10.0):
        """Divide every param group's learning rate by `divide_lr`."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] /= divide_lr
        return

    def _clip_weights(self):
        """
        TODO
        """
        raise NotImplementedError()

    def test(self, epoch):
        num_batches = len(self.test_dataset_loader)
        test_accuracies = AverageMeter()
        test_bin_accuracies = AverageMeter()
        test_multi_accuracies = AverageMeter()
        for batch_id, batch_data in enumerate(self.test_dataset_loader):
            current_step = batch_id
            # Set to eval()
            self.model.eval()
            # Get data from dataset
            im = batch_data['im'].cuda(non_blocking=True)
            gt = batch_data['gt'].cuda(non_blocking=True)
            gt = gt.view(gt.shape[0], gt.shape[2], gt.shape[3])

            # ============
            # Make prediction
            pred = self.model(im)
            pred_ans = torch.argmax(F.softmax(pred, dim=1), dim=1)
            bin_pred_ans = pred_ans != 0
            bin_pred_gt = gt != 0

            acc = 100.0 * torch.mean((pred_ans == gt).float())
            change_detected = bin_pred_ans == bin_pred_gt
            bin_acc = 100.0 * torch.mean(change_detected.float())
            # Logic: Multi class accuracy is union of change detected (pred_ans > 0) and if change was correctly determined
            change_detected[bin_pred_ans == 0] = 0
            multi_acc = 100.0 * torch.mean(
                (pred_ans[change_detected] == gt[change_detected]).float())

            test_accuracies.update(acc.item(), gt.shape[0])
            test_bin_accuracies.update(bin_acc.item(), gt.shape[0])
            if torch.sum(change_detected) > 0:
                test_multi_accuracies.update(multi_acc.item(), gt.shape[0])
            # ============
            # Print and Plot
            print(
                "TEST: Step: {}, Batch {}/{} has acc {}, multi acc {}, and binary acc {}"
                .format(current_step, batch_id, num_batches,
                        test_accuracies.avg, test_multi_accuracies.avg,
                        test_bin_accuracies.avg))
            if current_step % self.test_display_freq == 0:
                im = im[0, :, :, :].cpu()
                name = '{0}_{1}_{2}'.format(epoch, current_step, "image")
                #pdb.set_trace()
                mask = labelVisualize(pred_ans[0, :, :].detach().cpu().numpy(),
                                      self.num_classes, self.image_size)
                gt_label = labelVisualize(gt[0, :, :].detach().cpu().numpy(),
                                          self.num_classes, self.image_size)
                #pdb.set_trace()
                combined = visualizeAllCityImages(
                    im,
                    transforms.ToTensor()(gt_label),
                    transforms.ToTensor()(mask))
                self.txwriter.add_image("Test/" + name,
                                        transforms.ToTensor()(combined))
        return test_accuracies.avg, test_bin_accuracies.avg, test_multi_accuracies.avg

    def train(self):
        """
        Main training loop
        """

        for epoch in range(self.num_epochs):
            num_batches = len(self.train_dataset_loader)
            # Initialize running averages
            train_accuracies = AverageMeter()
            train_bin_accuracies = AverageMeter()
            train_multi_accuracies = AverageMeter()
            train_losses = AverageMeter()

            for batch_id, batch_data in enumerate(self.train_dataset_loader):
                self.model.train()  # Set the model to train mode
                current_step = epoch * num_batches + batch_id
                # Get data from dataset
                im = batch_data['im'].cuda(non_blocking=True)
                gt = batch_data['gt'].cuda(non_blocking=True)
                edges = batch_data['edges'].cuda(non_blocking=True)
                gt = gt.view(gt.shape[0], gt.shape[2], gt.shape[3])

                # ============
                # Make prediction
                pred = self.model(im)
                loss = self._optimize(pred, gt, edges)
                train_losses.update(loss.item(), gt.shape[0])
                #print(torch.argmax( F.softmax(pred, dim=1 ), dim=1 ).shape)
                #print(gt.shape)
                pred_ans = torch.argmax(F.softmax(pred, dim=1), dim=1)
                bin_pred_ans = pred_ans != 0
                bin_pred_gt = gt != 0

                acc = 100.0 * torch.mean((pred_ans == gt).float())
                change_detected = bin_pred_ans == bin_pred_gt
                bin_acc = 100.0 * torch.mean(change_detected.float())
                # Logic: Multi class accuracy is union of change detected (pred_ans > 0) and if change was correctly determined
                change_detected[bin_pred_ans == 0] = 0
                #pdb.set_trace()
                multi_acc = 100.0 * torch.mean(
                    (pred_ans[change_detected] == gt[change_detected]).float())

                train_accuracies.update(acc.item(), gt.shape[0])
                train_bin_accuracies.update(bin_acc.item(), gt.shape[0])
                if torch.sum(change_detected) > 0:
                    train_multi_accuracies.update(multi_acc.item(),
                                                  gt.shape[0])
                # ============
                # Not adjusting learning rate currently
                # if epoch % 100 == 99:
                #     self._adjust_learning_rate(epoch)
                # # Not Clipping Weights
                # self._clip_weights()

                if current_step % self.log_freq == 0:
                    print(
                        "Step: {}, Epoch: {}, Batch {}/{} has loss {}, acc {}, multi acc {}, and binary acc {}"
                        .format(current_step, epoch, batch_id, num_batches,
                                train_losses.avg, train_accuracies.avg,
                                train_multi_accuracies.avg,
                                train_bin_accuracies.avg))
                    self.txwriter.add_scalar('train/loss', train_losses.avg,
                                             current_step)
                    self.txwriter.add_scalar('train/accuracy',
                                             train_accuracies.avg,
                                             current_step)
                    self.txwriter.add_scalar('train/multi_class_accuracy',
                                             train_multi_accuracies.avg,
                                             current_step)
                    self.txwriter.add_scalar('train/binary_accuracy',
                                             train_bin_accuracies.avg,
                                             current_step)
                """
                Visualize some images
                """
                if current_step % self.display_freq == 0:
                    im = im[0, :, :, :].cpu()
                    name = '{0}_{1}_{2}'.format(epoch, current_step, "image")
                    #pdb.set_trace()
                    mask = labelVisualize(
                        pred_ans[0, :, :].detach().cpu().numpy(),
                        self.num_classes, self.image_size)
                    gt_label = labelVisualize(
                        gt[0, :, :].detach().cpu().numpy(), self.num_classes,
                        self.image_size)
                    #pdb.set_trace()
                    combined = visualizeAllCityImages(
                        im,
                        transforms.ToTensor()(gt_label),
                        transforms.ToTensor()(mask))
                    self.txwriter.add_image("Train/" + name,
                                            transforms.ToTensor()(combined))

                # Test accuracies
                if current_step % self.test_freq == 0:
                    self.model.eval()
                    test_accuracy, test_bin_accuracy, test_multi_accuracy = self.test(
                        epoch)
                    print("Epoch: {} has val accuracy {}".format(
                        epoch, test_accuracy))
                    self.txwriter.add_scalar('test/accuracy', test_accuracy,
                                             current_step)
                    self.txwriter.add_scalar('test/multi_class_accuracy',
                                             test_multi_accuracy, current_step)
                    self.txwriter.add_scalar('test/binary_accuracy',
                                             test_bin_accuracy, current_step)
                """
                Save Model periodically
                """
                if (current_step % self.save_freq == 0) and current_step > 0:
                    save_name1 = 'unet_encoder_checkpoint.pth'
                    save_name2 = 'unet_decoder_checkpoint.pth'
                    torch.save(self.model.encoder.state_dict(), save_name1)
                    torch.save(self.model.decoder.state_dict(), save_name2)
                    #torch.save(self.model.state_dict(), save_name)
                    print('Saved model to {}'.format(save_name1))
                    print('Saved model to {}'.format(save_name2))
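
AverageMeter is used throughout this example but not defined in the excerpt; the usual helper looks like this:

class AverageMeter:
    """Tracks a running sum and average of the values passed to update()."""
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
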
Example #8
def main():
    writer = SummaryWriter(args.snapshot_dir)
    
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    xlsor = XLSor(num_classes=args.num_classes)
    print(xlsor)

    saved_state_dict = torch.load(args.restore_from)
    new_params = xlsor.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        if i_parts[0] != 'fc':
            new_params['.'.join(i_parts)] = saved_state_dict[i]
    
    xlsor.load_state_dict(new_params)


    model = DataParallelModel(xlsor)
    model.train()
    model.float()
    model.cuda()    

    criterion = Criterion()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()
    
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)


    trainloader = data.DataLoader(XRAYDataSet(args.data_dir, args.data_list, max_iters=args.num_steps*args.batch_size, crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), 
                    batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)

    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, xlsor.parameters()), 'lr': args.learning_rate }],
                lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=input_size, mode='bilinear', align_corners=True)


    for i_iter, batch in enumerate(trainloader):
        i_iter += args.start_iters
        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.float().cuda()
        if torch_ver == "0.3":
            images = Variable(images)
            labels = Variable(labels)

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)
        preds = model(images, args.recurrence)

        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

        if i_iter % 100 == 0:
            images_inv = inv_preprocess(images, args.save_num_images, IMG_MEAN)
            if isinstance(preds, list):  # DataParallel output may be a nested list
                preds = preds[0]
            if isinstance(preds, list):
                preds = preds[0]
            preds = interp(preds)
            for index, img in enumerate(images_inv):
                writer.add_image('Images/'+str(index), torch.from_numpy(img/255.).permute(2,0,1), i_iter)
                writer.add_image('Labels/'+str(index), labels[index], i_iter)
                writer.add_image('preds/'+str(index), (preds[index]>0.5).float(), i_iter)

        print('iter = {} of {} completed, loss = {}'.format(i_iter, args.num_steps, loss.data.cpu().numpy()))

        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(xlsor.state_dict(),osp.join(args.snapshot_dir, 'XLSor_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(xlsor.state_dict(),osp.join(args.snapshot_dir, 'XLSor_'+str(i_iter)+'.pth'))

    end = timeit.default_timer()
    print(end-start,'seconds')
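
adjust_learning_rate is called at the top of the training loop but not defined in this excerpt; a poly-decay version consistent with the single param group created above (a hypothetical stand-in, with defaults in place of args.learning_rate and args.num_steps):

def adjust_learning_rate(optimizer, i_iter, base_lr=1e-2, num_steps=40000, power=0.9):
    # poly decay applied to the lone param group built in main()
    lr = base_lr * ((1.0 - float(i_iter) / num_steps) ** power)
    optimizer.param_groups[0]['lr'] = lr
    return lr
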
Example #9
            iter_loss += loss
            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p['nAveGrad'] == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(),
                                  ii + num_img_tr * epoch)
                loss_meter.add(iter_loss.cpu())
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0
                iter_loss = 0.0

            if ii % (num_img_tr // 20) == 0:
                grid_image = make_grid(inputs[:3].clone().cpu().data,
                                       3,
                                       normalize=True)
                writer.add_image('Image', grid_image, global_step)
                grid_image = make_grid(utils.decode_seg_map_sequence(
                    torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                                       3,
                                       normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(utils.decode_seg_map_sequence(
                    torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                                       3,
                                       normalize=False,
                                       range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
            #break
        # Save the model
        if (epoch % snapshot) == snapshot - 1:
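
A self-contained sketch of the nAveGrad accumulation pattern shown above (this variant also divides the loss so the accumulated step averages the per-batch gradients):

import torch
from torch import nn, optim

model = nn.Linear(8, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4  # plays the role of p['nAveGrad']

optimizer.zero_grad()
for step in range(1, 9):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = criterion(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across iterations
    if step % accum_steps == 0:  # update once every accum_steps forward passes
        optimizer.step()
        optimizer.zero_grad()
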
Example #10
class TensorboardHook(LoggingHook):
    """Hook for logging training process to tensorboard.

    Args:
        log_path (str): path to directory in which log files will be stored.
        metrics (list): metrics to log; each metric has to be a subclass of spk.Metric.
        log_train_loss (bool, optional): enable logging of training loss.
        log_validation_loss (bool, optional): enable logging of validation loss.
        log_learning_rate (bool, optional): enable logging of current learning rate.
        every_n_epochs (int, optional): epochs after which logging takes place.
        img_every_n_epochs (int, optional): epochs after which 2-D metrics
            are logged as images.
        log_histogram (bool, optional): enable logging of parameter histograms.

    """
    def __init__(
        self,
        log_path,
        metrics,
        log_train_loss=True,
        log_validation_loss=True,
        log_learning_rate=True,
        every_n_epochs=1,
        img_every_n_epochs=10,
        log_histogram=False,
    ):
        from tensorboardX import SummaryWriter

        super(TensorboardHook,
              self).__init__(log_path, metrics, log_train_loss,
                             log_validation_loss, log_learning_rate)
        self.writer = SummaryWriter(self.log_path)
        self.every_n_epochs = every_n_epochs
        self.log_histogram = log_histogram
        self.img_every_n_epochs = img_every_n_epochs

    def on_epoch_end(self, trainer):
        if trainer.epoch % self.every_n_epochs == 0:
            if self.log_train_loss:
                self.writer.add_scalar("train/loss",
                                       self._train_loss / self._counter,
                                       trainer.epoch)
            if self.log_learning_rate:
                self.writer.add_scalar(
                    "train/learning_rate",
                    trainer.optimizer.param_groups[0]["lr"],
                    trainer.epoch,
                )

    def on_validation_end(self, trainer, val_loss):
        if trainer.epoch % self.every_n_epochs == 0:
            for metric in self.metrics:
                m = metric.aggregate()

                if np.isscalar(m):
                    self.writer.add_scalar("metrics/%s" % metric.name,
                                           float(m), trainer.epoch)
                elif m.ndim == 2:
                    if trainer.epoch % self.img_every_n_epochs == 0:
                        import matplotlib.pyplot as plt

                        # tensorboardX only accepts images as numpy arrays.
                        # we therefore convert plots in numpy array
                        # see https://github.com/lanpa/tensorboard-pytorch/blob/master/examples/matplotlib_demo.py
                        fig = plt.figure()
                        plt.colorbar(plt.pcolor(m))
                        fig.canvas.draw()

                        np_image = np.frombuffer(fig.canvas.tostring_rgb(),
                                                 dtype="uint8")
                        np_image = np_image.reshape(
                            fig.canvas.get_width_height()[::-1] + (3, ))

                        plt.close(fig)

                        self.writer.add_image("metrics/%s" % metric.name,
                                              np_image, trainer.epoch,
                                              dataformats="HWC")

            if self.log_validation_loss:
                self.writer.add_scalar("train/val_loss", float(val_loss),
                                       trainer.step)

            if self.log_histogram:
                for name, param in trainer._model.named_parameters():
                    self.writer.add_histogram(name,
                                              param.detach().cpu().numpy(),
                                              trainer.epoch)

    def on_train_ends(self, trainer):
        self.writer.close()

    def on_train_failed(self, trainer):
        self.writer.close()
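
tensorboardX also offers add_figure, which rasterizes a matplotlib figure itself and avoids the manual canvas round-trip above; a minimal sketch:

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter("./logs")
fig = plt.figure()
plt.colorbar(plt.pcolor(np.random.rand(4, 4)))
writer.add_figure("metrics/demo", fig, global_step=0)  # closes the figure by default
writer.close()
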
Example #11
                    "cls": lsInfo["cls"] / cfg.train.batchSize,
                    #"L2":l2,
                    "l2": l2 * cfg.loss.l2,
                    "lr": lr
                },
                niter)

            # add graph only once
            if addGraphFlag:
                writer.add_graph(network, imgs)
                addGraphFlag = False

            # add img only once
            if addImgFlag:
                x = vutils.make_grid(images, normalize=True, scale_each=False)
                writer.add_image('Image/origin', x, startEpoch)

                x = vutils.make_grid(imgs, normalize=True, scale_each=False)
                writer.add_image('Image/normalized/', x, startEpoch)

                addImgFlag = False

            # add config to text
            if addCofigFlag:
                txt = ""
                for key, value in cfg.items():
                    for k, v in value.items():
                        txt += str(key) + " " + str(k) + ": " + str(
                            v) + "     \n"
                # print(txt)
                writer.add_text("config", txt)
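
This fragment opens mid-call; for reference, a complete add_scalars invocation with the same payload shape (illustrative values):

from tensorboardX import SummaryWriter

writer = SummaryWriter("./logs")
niter = 100
writer.add_scalars("loss", {"cls": 0.31, "l2": 0.02, "lr": 1e-3}, niter)
writer.close()
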
Example #12
def train(args):
    # Load data
    TrainDataset = SyntheticDataset(data_path=args.data_path,
                                    mode=args.mode,
                                    img_h=args.img_h,
                                    img_w=args.img_w,
                                    patch_size=args.patch_size,
                                    do_augment=args.do_augment)
    train_loader = DataLoader(TrainDataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    print('===> Train: There are totally {} training files'.format(len(TrainDataset)))

    net = HomographyModel(args.use_batch_norm)
    if args.resume:
        model_path = os.path.join(args.model_dir, args.model_name)
        ckpt = torch.load(model_path)
        net.load_state_dict(ckpt.state_dict())
    if torch.cuda.is_available():
        net = net.cuda()

    optimizer = optim.Adam(net.parameters(), lr=args.lr)  # default as 0.0001
    decay_rate = 0.96
    step_size = (math.log(decay_rate) * args.max_epochs) / math.log(args.min_lr * 1.0 / args.lr)
    print('args lr:', args.lr, args.min_lr)
    print('===> Decay steps:', step_size)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(step_size), gamma=0.96)

    print("start training")
    writer = SummaryWriter(logdir=args.log_dir, flush_secs=60)
    score_print_fre = 100
    summary_fre = 1000
    model_save_fre = 4000
    glob_iter = 0
    t0 = time.time()

    for epoch in range(args.max_epochs):
        net.train()
        epoch_start = time.time()
        train_l1_loss = 0.0
        train_l1_smooth_loss = 0.0
        train_h_loss = 0.0

        for i, batch_value in enumerate(train_loader):
            I1_batch = batch_value[0].float()
            I2_batch = batch_value[1].float()
            I1_aug_batch = batch_value[2].float()
            I2_aug_batch = batch_value[3].float()
            I_batch = batch_value[4].float()
            I_prime_batch = batch_value[5].float()
            pts1_batch = batch_value[6].float()
            gt_batch = batch_value[7].float()
            patch_indices_batch = batch_value[8].float()

            if torch.cuda.is_available():
                I1_aug_batch = I1_aug_batch.cuda()
                I2_aug_batch = I2_aug_batch.cuda()
                I_batch = I_batch.cuda()
                pts1_batch = pts1_batch.cuda()
                gt_batch = gt_batch.cuda()
                patch_indices_batch = patch_indices_batch.cuda()

            # forward, backward, update weights
            optimizer.zero_grad()
            batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch)
            h_loss = batch_out['h_loss']
            rec_loss = batch_out['rec_loss']
            ssim_loss = batch_out['ssim_loss']
            l1_loss = batch_out['l1_loss']
            l1_smooth_loss = batch_out['l1_smooth_loss']
            ncc_loss = batch_out['ncc_loss']
            pred_I2 = batch_out['pred_I2']

            loss = l1_loss
            loss.backward()
            optimizer.step()

            train_l1_loss += loss.item()
            train_l1_smooth_loss += l1_smooth_loss.item()
            train_h_loss += h_loss.item()
            if (i + 1) % score_print_fre == 0 or (i + 1) == len(train_loader):
                print(
                    "Training: Epoch[{:0>3}/{:0>3}] Iter[{:0>3}]/[{:0>3}] l1 loss: {:.4f} "
                    "l1 smooth loss: {:.4f} h loss: {:.4f} lr={:.8f}".format(
                        epoch + 1, args.max_epochs, i + 1, len(train_loader), train_l1_loss / score_print_fre,
                        train_l1_smooth_loss / score_print_fre, train_h_loss / score_print_fre, scheduler.get_lr()[0]))
                train_l1_loss = 0.0
                train_l1_smooth_loss = 0.0
                train_h_loss = 0.0

            if glob_iter % summary_fre == 0:
                writer.add_scalar('learning_rate', scheduler.get_lr()[0], glob_iter)
                writer.add_scalar('h_loss', h_loss, glob_iter)
                writer.add_scalar('rec_loss', rec_loss, glob_iter)
                writer.add_scalar('ssim_loss', ssim_loss, glob_iter)
                writer.add_scalar('l1_loss', l1_loss, glob_iter)
                writer.add_scalar('l1_smooth_loss', l1_smooth_loss, glob_iter)
                writer.add_scalar('ncc_loss', ncc_loss, glob_iter)

                writer.add_image('I', utils.denorm_img(I_batch[0, ...].cpu().numpy()).astype(np.uint8)[:, :, ::-1],
                                 glob_iter, dataformats='HWC')
                writer.add_image('I_prime',
                                 utils.denorm_img(I_prime_batch[0, ...].numpy()).astype(np.uint8)[:, :, ::-1],
                                 glob_iter, dataformats='HWC')

                writer.add_image('I1_aug', utils.denorm_img(I1_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
                                 glob_iter, dataformats='HW')
                writer.add_image('I2_aug', utils.denorm_img(I2_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
                                 glob_iter, dataformats='HW')
                writer.add_image('pred_I2',
                                 utils.denorm_img(pred_I2[0, 0, ...].cpu().detach().numpy()).astype(np.uint8),
                                 glob_iter, dataformats='HW')

                writer.add_image('I2', utils.denorm_img(I2_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
                                 dataformats='HW')
                writer.add_image('I1', utils.denorm_img(I1_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
                                 dataformats='HW')

            # save model
            if glob_iter % model_save_fre == 0 and glob_iter != 0:
                filename = 'model' + '_iter_' + str(glob_iter) + '.pth'
                model_save_path = os.path.join(args.model_dir, filename)
                torch.save(net, model_save_path)

            glob_iter += 1
        scheduler.step()
        print("Epoch: {} epoch time: {:.1f}s".format(epoch, time.time() - epoch_start))

    elapsed_time = time.time() - t0
    print("Finished Training in {:.0f}h {:.0f}m {:.0f}s.".format(
        elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60))
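
The step_size formula above solves lr * decay_rate**(max_epochs / step_size) = min_lr for step_size, i.e. how many epochs between 0.96x decays so the learning rate lands on min_lr at the end of training; a quick numeric check with illustrative values:

import math

lr, min_lr, max_epochs, decay_rate = 1e-4, 1e-6, 120, 0.96
step_size = (math.log(decay_rate) * max_epochs) / math.log(min_lr / lr)
print(step_size)  # ~1.06: decay by 0.96 about every epoch
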
Example #13
class Trainer(object):
    def __init__(self,
                 generator,
                 discriminator,
                 optimizer_generator,
                 optimizer_discriminator,
                 losses_list,
                 metrics_list=None,
                 device=torch.device("cuda:0"),
                 ndiscriminator=-1,
                 batch_size=128,
                 sample_size=8,
                 epochs=5,
                 checkpoints="./model/gan",
                 retain_checkpoints=5,
                 recon="./images",
                 test_noise=None,
                 log_tensorboard=True,
                 **kwargs):
        self.device = device
        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)
        if "optimizer_generator_options" in kwargs:
            self.optimizer_generator = optimizer_generator(
                self.generator.parameters(),
                **kwargs["optimizer_generator_options"])
        else:
            self.optimizer_generator = optimizer_generator(
                self.generator.parameters())
        if "optimizer_discriminator_options" in kwargs:
            self.optimizer_discriminator = optimizer_discriminator(
                self.discriminator.parameters(),
                **kwargs["optimizer_discriminator_options"])
        else:
            self.optimizer_discriminator = optimizer_discriminator(
                self.discriminator.parameters())
        self.losses = {}
        self.loss_logs = {}
        for loss in losses_list:
            name = type(loss).__name__
            self.loss_logs[name] = []
            self.losses[name] = loss
        if metrics_list is None:
            self.metrics = None
            self.metric_logs = None
        else:
            self.metric_logs = {}
            self.metrics = {}
            for metric in metrics_list:
                name = type(metric).__name__
                self.metric_logs[name] = []
                self.metrics[name] = metric
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.epochs = epochs
        self.checkpoints = checkpoints
        self.retain_checkpoints = retain_checkpoints
        self.recon = recon
        self.test_noise = torch.randn(
            self.sample_size, self.generator.encoding_dims,
            device=self.device) if test_noise is None else test_noise
        # Placeholders: train_iter() overwrites these every batch, but the
        # loss-argument maps require the attributes to exist up front.
        self.noise = torch.randn(1)
        self.real_inputs = torch.randn(1)
        self.labels = torch.randn(1)

        self.loss_information = {
            'generator_losses': 0.0,
            'discriminator_losses': 0.0,
            'generator_iters': 0,
            'discriminator_iters': 0,
        }
        self.ndiscriminator = ndiscriminator
        if "loss_information" in kwargs:
            self.loss_information.update(kwargs["loss_information"])
        if "loss_logs" in kwargs:
            self.loss_logs.update(kwargs["loss_logs"])
        if "metric_logs" in kwargs:
            self.metric_logs.update(kwargs["metric_logs"])
        self.start_epoch = 0
        self.last_retained_checkpoint = 0
        self.writer = SummaryWriter()
        self.log_tensorboard = log_tensorboard
        if self.log_tensorboard:
            self.tensorboard_information = {
                "step": 0,
                "repeat_step": 4,
                "repeats": 1
            }
        self.nrow = kwargs.get("display_rows", 8)
        self.labels_provided = kwargs.get("labels_provided", False)

    def save_model_extras(self, save_path):
        return {}

    def save_model(self, epoch):
        if self.last_retained_checkpoint == self.retain_checkpoints:
            self.last_retained_checkpoint = 0
        save_path = self.checkpoints + str(
            self.last_retained_checkpoint) + '.model'
        self.last_retained_checkpoint += 1
        print("Saving Model at '{}'".format(save_path))
        model = {
            'epoch': epoch + 1,
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'optimizer_discriminator':
            self.optimizer_discriminator.state_dict(),
            'loss_information': self.loss_information,
            'loss_objects': self.losses,
            'metric_objects': self.metrics,
            'loss_logs': self.loss_logs,
            'metric_logs': self.metric_logs
        }
        # FIXME(avik-pal): Not a very good function name
        model.update(self.save_model_extras(save_path))
        torch.save(model, save_path)

    def load_model_extras(self, load_path):
        pass

    def load_model(self, load_path=""):
        if load_path == "":
            load_path = self.checkpoints + str(
                self.last_retained_checkpoint) + '.model'
        print("Loading Model From '{}'".format(load_path))
        try:
            check = torch.load(load_path)
            self.start_epoch = check['epoch']
            self.losses = check['loss_objects']
            self.metrics = check['metric_objects']
            self.loss_information = check['loss_information']
            self.loss_logs = check['loss_logs']
            self.metric_logs = check['metric_logs']
            self.generator.load_state_dict(check['generator'])
            self.discriminator.load_state_dict(check['discriminator'])
            self.optimizer_generator.load_state_dict(
                check['optimizer_generator'])
            self.optimizer_discriminator.load_state_dict(
                check['optimizer_discriminator'])
            # FIXME(avik-pal): Not a very good function name
            self.load_model_extras(check)
        except Exception:
            warn("Model could not be loaded from {}. Training from Scratch".
                 format(load_path))
            self.start_epoch = 0
            self.generator_losses = []
            self.discriminator_losses = []

    # TODO(avik-pal): The _get_step will fail in a lot of cases
    def _get_step(self, update=True):
        if not update:
            return self.tensorboard_information["step"]
        if self.tensorboard_information[
                "repeats"] < self.tensorboard_information["repeat_step"]:
            self.tensorboard_information["repeats"] += 1
            return self.tensorboard_information["step"]
        else:
            self.tensorboard_information["step"] += 1
            self.tensorboard_information["repeats"] = 1
            return self.tensorboard_information["step"]

    def sample_images(self, epoch):
        save_path = "{}/epoch{}.png".format(self.recon, epoch + 1)
        print("Generating and Saving Images to {}".format(save_path))
        with torch.no_grad():
            images = self.generator(self.test_noise.to(self.device))
            img = torchvision.utils.make_grid(images)
            torchvision.utils.save_image(img, save_path, nrow=self.nrow)
            if self.log_tensorboard:
                self.writer.add_image("Generated Samples", img,
                                      self._get_step(False))

    def train_logger(self, epoch, running_losses):
        print('Epoch {} Summary: '.format(epoch + 1))
        for name, val in running_losses.items():
            print('Mean {} : {}'.format(name, val))

    def tensorboard_log_losses(self):
        if self.log_tensorboard:
            running_generator_loss = self.loss_information["generator_losses"] /\
                self.loss_information["generator_iters"]
            running_discriminator_loss = self.loss_information["discriminator_losses"] /\
                self.loss_information["discriminator_iters"]
            self.writer.add_scalar("Running Discriminator Loss",
                                   running_discriminator_loss,
                                   self._get_step())
            self.writer.add_scalar("Running Generator Loss",
                                   running_generator_loss, self._get_step())
            self.writer.add_scalars(
                "Running Losses", {
                    "Running Discriminator Loss": running_discriminator_loss,
                    "Running Generator Loss": running_generator_loss
                }, self._get_step())

    def tensorboard_log_metrics(self):
        if self.log_tensorboard:
            for name, value in self.loss_logs.items():
                if type(value) is tuple:
                    self.writer.add_scalar('Losses/{}-Generator'.format(name),
                                           value[0], self._get_step(False))
                    self.writer.add_scalar(
                        'Losses/{}-Discriminator'.format(name), value[1],
                        self._get_step(False))
                else:
                    self.writer.add_scalar('Losses/{}'.format(name), value,
                                           self._get_step(False))
            if self.metric_logs:
                for name, value in self.metric_logs.items():
                    # FIXME(Aniket1998): Metrics step should be number of epochs so far
                    self.writer.add_scalar("Metrics/{}".format(name), value,
                                           self._get_step(False))

    def _get_argument_maps(self, loss):
        sig = signature(loss.train_ops)
        args = list(sig.parameters.keys())
        for arg in args:
            if arg not in self.__dict__:
                raise Exception(
                    "Argument : {} needed for {} not present".format(
                        arg,
                        type(loss).__name__))
        return args

    def _store_loss_maps(self):
        self.loss_arg_maps = {}
        for name, loss in self.losses.items():
            self.loss_arg_maps[name] = self._get_argument_maps(loss)

    def train_stopper(self):
        if self.ndiscriminator == -1:
            return False
        else:
            return self.loss_information[
                "discriminator_iters"] % self.ndiscriminator != 0

    def train_iter_custom(self):
        pass

    # TODO(avik-pal): Clean up this function and avoid returning values
    def train_iter(self):
        self.train_iter_custom()
        ldis, lgen, dis_iter, gen_iter = 0.0, 0.0, 0, 0
        for name, loss in self.losses.items():
            if isinstance(loss, GeneratorLoss) and isinstance(
                    loss, DiscriminatorLoss):
                cur_loss = loss.train_ops(*itemgetter(
                    *self.loss_arg_maps[name])(self.__dict__))
                self.loss_logs[name].append(cur_loss)
                if type(cur_loss) is tuple:
                    lgen, ldis, gen_iter, dis_iter = lgen + cur_loss[0], ldis + cur_loss[1],\
                        gen_iter + 1, dis_iter + 1
            elif isinstance(loss, GeneratorLoss):
                if self.ndiscriminator == -1 or\
                   self.loss_information["discriminator_iters"] % self.ndiscriminator == 0:
                    cur_loss = loss.train_ops(*itemgetter(
                        *self.loss_arg_maps[name])(self.__dict__))
                    self.loss_logs[name].append(cur_loss)
                    lgen, gen_iter = lgen + cur_loss, gen_iter + 1
            elif isinstance(loss, DiscriminatorLoss):
                cur_loss = loss.train_ops(*itemgetter(
                    *self.loss_arg_maps[name])(self.__dict__))
                self.loss_logs[name].append(cur_loss)
                ldis, dis_iter = ldis + cur_loss, dis_iter + 1
        return lgen, ldis, gen_iter, dis_iter

    def log_metrics(self, epoch):
        if not self.metric_logs:
            warn('No evaluation metric logs present')
        else:
            for name, val in self.metric_logs.items():
                print('{} : {}'.format(name, val))
            self.tensorboard_log_metrics()

    def eval_ops(self, epoch, **kwargs):
        self.sample_images(epoch)
        if self.metrics is not None:
            for name, metric in self.metrics.items():
                if name + '_inputs' not in kwargs:
                    raise Exception(
                        "Inputs not provided for metric {}".format(name))
                else:
                    self.metric_logs[name].append(
                        metric.metric_ops(self.generator, self.discriminator,
                                          kwargs[name + '_inputs']))
                    self.log_metrics(epoch)

    def train(self, data_loader, **kwargs):
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.start_epoch, self.epochs):
            self.generator.train()
            self.discriminator.train()
            for data in data_loader:
                if type(data) is tuple:
                    if not data[0].size()[0] == self.batch_size:
                        continue
                    self.real_inputs = data[0].to(self.device)
                    self.labels = data[1].to(self.device)
                else:
                    if not data.size()[0] == self.batch_size:
                        continue
                    self.real_inputs = data.to(self.device)

                self.noise = torch.randn(self.batch_size,
                                         self.generator.encoding_dims,
                                         device=self.device)

                lgen, ldis, gen_iter, dis_iter = self.train_iter()
                self.loss_information['generator_losses'] += lgen
                self.loss_information['discriminator_losses'] += ldis
                self.loss_information['generator_iters'] += gen_iter
                self.loss_information['discriminator_iters'] += dis_iter

                self.tensorboard_log_losses()

                if self.train_stopper():
                    break

            self.save_model(epoch)
            self.train_logger(
                epoch, {
                    'Generator Loss':
                    self.loss_information['generator_losses'] /
                    self.loss_information['generator_iters'],
                    'Discriminator Loss':
                    self.loss_information['discriminator_losses'] /
                    self.loss_information['discriminator_iters']
                })
            self.generator.eval()
            self.discriminator.eval()
            self.eval_ops(epoch, **kwargs)

        print("Training of the Model is Complete")

    def __call__(self, data_loader, **kwargs):
        self._store_loss_maps()
        self.train(data_loader, **kwargs)
        self.writer.close()
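A hedged usage sketch for the Trainer above: the constructor takes optimizer classes (not instances) plus optional `*_options` kwargs, and calling the trainer builds the loss-argument maps before training. The model and loss classes named here are assumed stand-ins for whatever the surrounding project provides.

from torch.optim import Adam

trainer = Trainer(
    generator=DCGANGenerator(encoding_dims=100),   # hypothetical generator
    discriminator=DCGANDiscriminator(),            # hypothetical discriminator
    optimizer_generator=Adam,                      # a class, not an instance
    optimizer_discriminator=Adam,
    losses_list=[MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
    epochs=10,
    optimizer_generator_options={'lr': 2e-4, 'betas': (0.5, 0.999)},
)
trainer(data_loader)  # __call__ stores the loss-argument maps, then trains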
Exemplo n.º 14
0
    dummy_s2 = torch.rand(1)
    # data grouping by `slash`
    writer.add_scalar('data/scalar1', dummy_s1[0], n_iter)
    writer.add_scalar('data/scalar2', dummy_s2[0], n_iter)

    writer.add_scalars(
        'data/scalar_group', {
            'xsinx': n_iter * np.sin(n_iter),
            'xcosx': n_iter * np.cos(n_iter),
            'arctanx': np.arctan(n_iter)
        }, n_iter)

    dummy_img = torch.rand(32, 3, 64, 64)  # output from network
    if n_iter % 10 == 0:
        x = vutils.make_grid(dummy_img, normalize=True, scale_each=True)
        writer.add_image('Image', x, n_iter)

        dummy_audio = torch.zeros(sample_rate * 2)
        for i in range(dummy_audio.size(0)):
            # amplitude of sound should be in [-1, 1]
            dummy_audio[i] = np.cos(freqs[n_iter // 10] * np.pi * float(i) /
                                    float(sample_rate))
        writer.add_audio('myAudio',
                         dummy_audio,
                         n_iter,
                         sample_rate=sample_rate)

        writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)

        for name, param in resnet18.named_parameters():
            writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)
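For reference, a minimal self-contained `add_histogram` call, independent of the truncated demo above:

import torch
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/hist_demo')
for step in range(3):
    values = torch.randn(1000) * (step + 1)  # a widening distribution
    writer.add_histogram('activations', values.numpy(), step)
writer.close()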
Exemplo n.º 15
0
class SummaryMaker():
    """This class is mostly a wrapper around the tensorboardX SummaryWriter,
    intended to carry the current epoch index, and implement the processing of
    the outputs for more elaborate summaries.
    """

    def __init__(self, log_dir, params, up_factor):
        """
        Args:
            log_dir: (str) The path to the folder where the summary files are
                going to be written. The summary object creates a train and a
                val folders to store the summary files.
            params: (train.utils.Params) The parameters loaded from the
                parameters.json file.
            up_factor: (int) The upscale factor that indicates how much the
                score maps need to be upscaled to match the original scale
                (used when superimposing the embeddings and score maps on the
                input images).

        Attributes:
            writer_train: (tensorboardX.writer.SummaryWriter) The tensorboardX
                writer that writes the training information.
            writer_val: (tensorboardX.writer.SummaryWriter) The tensorboardX
                writer that writes the validation information.
            epoch: (int) Stores the current epoch.
            ref_sz: (int) The size in pixels of the reference image.
            srch_sz: (int) The size in pixels of the search image.
            up_factor: (int) The upscale factor. See Args.

        """
        # We use two different summary writers so we can plot both curves in
        # the same plot, as suggested in https://www.quora.com/How-do-you-plot-training-and-validation-loss-on-the-same-graph-using-TensorFlow%E2%80%99s-TensorBoard
        self.writer_train = SummaryWriter(join(log_dir, 'train'))
        self.writer_val = SummaryWriter(join(log_dir, 'val'))
        self.epoch = None
        self.ref_sz = params.reference_sz
        self.srch_sz = params.search_sz
        self.up_factor = up_factor

    def add_epochwise_scalar(self, mode, tag, scalar_value):
        """ Wraps the writer.add_scalar function, using the attribute epoch as
        global_step.

        Args:
            mode: (str) Indicates whether 'train' or 'val'.
            tag: (str) The tag identifying the scalar.
            scalar_value: (scalar) The value of the scalar.
        """
        if mode == 'train':
            self.writer_train.add_scalar(tag, scalar_value, self.epoch)
        elif mode == 'val':
            self.writer_val.add_scalar(tag, scalar_value, self.epoch)

    def add_overlay(self, tag, embed, img, alpha=0.8, cmap='inferno', add_ref=None):
        """ Adds to the summary the images of the input image (ref or search)
        overlayed with the corresponding embedding or correlation map. It expect
        tensors of with dimensions [C x H x W] or [B x C x H x W] if the tensor
        has a batch dimension it takes the FIRST ELEMENT of the batch. The image
        is displayed as fusion of the input image in grayscale and the overlay
        in the chosen color_map, this fusion is controlled by the alpha factor.
        In the case of the embeddings, since there are multiple feature
        channels, we show each of them individually in a grid.
        OBS: The colors represent relative values, where the peak color corresponds
        to the maximum value in any given channel, so no direct value comparisons
        can be made between epochs, only the relative distribution of neighboring
        pixel values, (which should be enough, since we are mosly interested
        in finding the maximum of a given correlation map)

        Args:
            tag: (str) The string identifying the image in tensorboard, images
                with the same tag are grouped together with a slider, and are
                indexed by epoch.
            embed: (torch.Tensor) The tensor containing the embedding of an
                input (ref or search image) or a correlation map (the final
                output). The shape should be [B, C, H, W] or [B, H, W] for the
                case of the correlation map.
            img: (torch.Tensor) The image on top of which the embed is going
                to be overlaid. Reference image embeddings should be overlaid
                on top of reference images and search image embeddings as well
                as the correlation maps should be overlaid on top of the search
                images.
            alpha: (float) A mixing variable, it controls how much of the final
                embedding corresponds to the grayscale input image and how much
                corresponds to the overlay. Alpha = 0, means there is no
                overlay in the final image, only the input image. Conversely,
                Alpha = 1 means there is only overlay. Adjust this value so
                you can distinctly see the overlay details while still seeing
                where it is in relation to the original image.
            cmap: (str) The name of the colormap to be used with the overlay.
                The colormaps are defined in the colormaps.py module, but values
                include 'viridis' (greenish blue) and 'inferno' (yellowish red).
            add_ref: (torch.Tensor) Optional. An additional reference image that
                will be plotted to the side of the other images. Useful when
                plotting correlation maps, because it lets the user see both
                the search image and the reference that is used as the target.

        ``Example``
            >>> summ_maker = SummaryMaker(os.path.join(exp_dir, 'tensorboard'), params,
                                           model.upscale_factor)
            ...
            >>> embed_ref = model.get_embedding(ref_img_batch)
            >>> embed_srch = model.get_embedding(search_batch)
            >>> output_batch = model.match_corr(embed_ref, embed_srch)
            >>> batch_index = 0
            >>> summ_maker.add_overlay("Ref_image_{}".format(tbx_index), embed_ref[batch_index], ref_img_batch[batch_index], cmap='inferno')
            >>> summ_maker.add_overlay("Search_image_{}".format(tbx_index), embed_srch[batch_index], search_batch[batch_index], cmap='inferno')
            >>> summ_maker.add_overlay("Correlation_map_{}".format(tbx_index), output_batch[batch_index], search_batch[batch_index], cmap='inferno')
        """
        # TODO Add numbers in the final image to the feature channels.
        # TODO Add the color bar showing the progression of values.
        # If minibatch is given, take only the first image
        # TODO let the user select the image? Loop on all images?
        if len(embed.shape) == 4:
            embed = embed[0]
        if len(img.shape) == 4:
            img = img[0]
        # Normalize the image.
        img = img - img.min()
        img = img/img.max()
        embed = cm.apply_cmap(embed, cmap=cmap)
        # Get grayscale version of image by taking the weighted average of the channels
        # as described in https://www.cs.virginia.edu/~vicente/recognition/notebooks/image_processing_lab.html#2.-Converting-to-Grayscale
        R,G,B = img
        img_gray = 0.21 * R + 0.72 * G + 0.07 * B
        # Get the upscaled size of the embedding, so as to take into account
        # the network's downscale caused by the stride.
        upsc_size = (embed.shape[-1] - 1) * self.up_factor + 1
        embed = F.interpolate(embed, upsc_size, mode='bilinear',
                              align_corners=False)
        # Pad the embedding with zeros to match the image dimensions. We pad
        # all 4 corners equally to keep the embedding centered.
        tot_pad = img.shape[-1] - upsc_size
        # Sanity check 1. The amount of padding must be equal on all sides, so
        # the total padding on any dimension must be an even integer.
        assert tot_pad % 2 == 0, "The embed or image dimensions are incorrect."
        pad = int(tot_pad/2)
        embed = F.pad(embed, (pad, pad, pad, pad), 'constant', 0)
        # Sanity check 2: the size of the embedding in the (H, W) dimensions
        # matches the size of the image.
        assert embed.shape[-2:] == img.shape[-2:], ("The embedding overlay "
                                                    "and image dimensions "
                                                    "do not agree.")
        final_imgs = alpha * embed + (1-alpha) * img_gray
        # The embedding_channel (or feature channel) dimension is treated like
        # a batch dimension, so the grid shows each individual embedding
        # overlaid with the input image. The original image is also shown.
        # If add_ref is used, the ref image is the first to be shown.
        img = img.unsqueeze(0)
        final_imgs = torch.cat((img, final_imgs))
        if add_ref is not None:
            # Pads the image if necessary
            pad = int((img.shape[-1] - add_ref.shape[-1])//2)
            add_ref = F.pad(add_ref, (pad, pad, pad, pad), 'constant', 0)
            add_ref = add_ref.unsqueeze(0)
            final_imgs = torch.cat((add_ref, final_imgs))
        final_imgs = make_grid(final_imgs, nrow=6)
        self.writer_val.add_image(tag, final_imgs, self.epoch)
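The dual-writer trick described in the constructor comment reduces to this: log the same tag from two writers whose log directories share a parent, and TensorBoard overlays both runs on one chart.

from tensorboardX import SummaryWriter

writer_train = SummaryWriter('logs/train')
writer_val = SummaryWriter('logs/val')
for epoch in range(5):
    writer_train.add_scalar('loss', 1.0 / (epoch + 1), epoch)
    writer_val.add_scalar('loss', 1.2 / (epoch + 1), epoch)  # same tag
writer_train.close()
writer_val.close()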
Exemplo n.º 16
0
class TensorBoardImages(Callback):
    """The TensorBoardImages callback will write a selection of images from the validation pass to tensorboard using the
    TensorboardX library and torchvision.utils.make_grid
    """

    def __init__(self, log_dir='./logs',
                 comment='torchbearer',
                 name='Image',
                 key=torchbearer.Y_PRED,
                 write_each_epoch=True,
                 num_images=16,
                 nrow=8,
                 padding=2,
                 normalize=False,
                 range=None,
                 scale_each=False,
                 pad_value=0):
        """Create TensorBoardImages callback which writes images from the given key to the given path. Full name of
        image sub directory will be model name + _ + comment.

        :param log_dir: The tensorboard log path for output
        :type log_dir: str
        :param comment: Descriptive comment to append to path
        :type comment: str
        :param name: The name of the image
        :type name: str
        :param key: The key in state containing image data (tensor of size [c, w, h] or [b, c, w, h])
        :type key: str
        :param write_each_epoch: If True, write data on every epoch, else write only for the first epoch.
        :type write_each_epoch: bool
        :param num_images: The number of images to write
        :type num_images: int
        :param nrow: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        :param padding: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        :param normalize: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        :param range: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        :param scale_each: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        :param pad_value: See `torchvision.utils.make_grid <https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid>`_
        """
        self.log_dir = log_dir
        self.comment = comment
        self.name = name
        self.key = key
        self.write_each_epoch = write_each_epoch
        self.num_images = num_images
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.range = range
        self.scale_each = scale_each
        self.pad_value = pad_value

        self._writer = None
        self._data = None
        self.done = False

    def on_start(self, state):
        log_dir = os.path.join(self.log_dir, state[torchbearer.MODEL].__class__.__name__ + '_' + self.comment)
        self._writer = SummaryWriter(log_dir=log_dir)

    def on_step_validation(self, state):
        if not self.done:
            data = state[self.key].clone()

            if len(data.size()) == 3:
                data = data.unsqueeze(1)

            if self._data is None:
                remaining = self.num_images if self.num_images < data.size(0) else data.size(0)

                self._data = data[:remaining].to('cpu')
            else:
                remaining = self.num_images - self._data.size(0)

                if remaining > data.size(0):
                    remaining = data.size(0)

                self._data = torch.cat((self._data, data[:remaining].to('cpu')), dim=0)

            if self._data.size(0) >= self.num_images:
                image = utils.make_grid(
                    self._data,
                    nrow=self.nrow,
                    padding=self.padding,
                    normalize=self.normalize,
                    range=self.range,
                    scale_each=self.scale_each,
                    pad_value=self.pad_value
                )
                self._writer.add_image(self.name, image, state[torchbearer.EPOCH])
                self.done = True
                self._data = None

    def on_end_epoch(self, state):
        if self.write_each_epoch:
            self.done = False

    def on_end(self, state):
        self._writer.close()
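A sketch of how this callback is typically attached, assuming torchbearer's Trial API with a model, optimizer, criterion, and data loaders already defined elsewhere:

import torchbearer
from torchbearer import Trial

callback = TensorBoardImages(comment='vae', key=torchbearer.Y_PRED, num_images=16)
trial = Trial(model, optimizer, criterion, metrics=['loss'],
              callbacks=[callback]).with_generators(train_loader, val_loader)
trial.run(epochs=10)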
Exemplo n.º 17
0
    def train(self):

        self.save_config()
        print("\n# Training", file=sys.stderr)
        args = self.args
        model, gen_opt, gen_scheduler = self.model, self.gen_opt, self.gen_scheduler

        if args.logdir:
            from tensorboardX import SummaryWriter

            writer = SummaryWriter(args.logdir)
        else:
            writer = None

        step = 1

        for epoch in range(args.epochs):

            iterator = tqdm(self.get_batcher(self.train_loader))

            for x_mb, y_mb in iterator:
                # [B, H*W]
                x_mb = x_mb.reshape(-1, args.height * args.width)
                # [B, 10]
                context = y_mb.float() if args.conditional else None
                model.train()
                gen_opt.zero_grad()

                if args.num_masks == 1:
                    resample_mask = False
                else:  # training with variable masks
                    resample_mask = (args.resample_mask_every > 0
                                     and step % args.resample_mask_every == 0)

                # [B, H*W]
                noisy_x = torch.where(
                    torch.rand_like(x_mb) > args.input_dropout,
                    x_mb,
                    torch.zeros_like(x_mb),
                )
                p_x = model(inputs=context,
                            history=noisy_x,
                            resample_mask=resample_mask)
                # [B, H*W]
                ll_mb = p_x.log_prob(x_mb)
                # [B]
                ll = ll_mb.sum(-1)

                loss = -(ll).mean()
                loss.backward()
                gen_opt.step()

                display = OrderedDict()
                display["0s"] = "{:.2f}".format(
                    (x_mb == 0).float().mean().item())
                display["1s"] = "{:.2f}".format(
                    (x_mb == 1).float().mean().item())
                display["NLL"] = "{:.2f}".format(-ll.mean().item())

                if writer:
                    writer.add_scalar("training/LL", ll.mean().item(), step)
                    # NOTE: `z` (a posterior sample) is not defined in this
                    # snippet and must come from the surrounding code.
                    writer.add_image("training/posterior/sample",
                                     z.mean(0).reshape(1, 1, -1) * 255, step)

                iterator.set_postfix(display, refresh=False)
                step += 1

            stop, dict_valid = validate(
                self.get_batcher(self.valid_loader),
                args,
                model,
                gen_opt,
                gen_scheduler,
                writer=writer,
                name="dev",
            )

            if stop:
                print("Early stopping at epoch {:3}/{}".format(
                    epoch + 1, args.epochs))
                break

            print("Epoch {:3}/{} -- ".format(epoch + 1, args.epochs) +
                  ", ".join([
                      "{}: {:4.2f}".format(k, v)
                      for k, v in sorted(dict_valid.items())
                  ]))

        print("Loading best model...")
        self.load()
        print("Validation results")
        val_dict = self.validate()
        print(
            "dev",
            " ".join([
                "{}={:4.2f}".format(k, v) for k, v in sorted(val_dict.items())
            ]),
        )
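The `torch.where` call in the loop above implements input dropout: each pixel survives with probability 1 - input_dropout and is zeroed otherwise. In isolation:

import torch

x = torch.rand(4, 784)                     # a batch of flattened images
input_dropout = 0.3
keep = torch.rand_like(x) > input_dropout  # True where the pixel survives
noisy_x = torch.where(keep, x, torch.zeros_like(x))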
Exemplo n.º 18
0
class Train:
    __device = []
    __writer = []
    __model = []
    __transformations = []
    __dataset_train = []
    __train_loader = []
    __loss_func = []
    __optimizer = []
    __exp_lr_scheduler = []

    def __init__(self, gpu='0'):
        # Device configuration
        self.__device = torch.device('cuda:'+gpu if torch.cuda.is_available() else 'cpu')
        self.__writer = SummaryWriter('logs')
        self.__model = CNNDriver()
        # Set model to train mode
        self.__model.train()
        print(self.__model)
        self.__writer.add_graph(self.__model, torch.rand(10, 3, 66, 200))
        # Put model on GPU
        self.__model = self.__model.to(self.__device)

    def train(self, num_epochs=100, batch_size=400, lr=0.0001, l2_norm=0.001, save_dir='./save', input='./DataLMDB'):
        # Create log/save directory if it does not exist
        if not os.path.exists('./logs'):
            os.makedirs('./logs')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        self.__transformations = transforms.Compose([AugmentDrivingTransform(), 
                                                     RandomBrightness(), ConvertToGray(), 
                                                     ConvertToSepia(), AddNoise(), DrivingDataToTensor(),])
        self.__dataset_train = DriveData_LMDB(input, self.__transformations)
        self.__train_loader = DataLoader(self.__dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)

        # Loss and Optimizer
        self.__loss_func = nn.MSELoss()
        # self.__loss_func = nn.SmoothL1Loss()
        self.__optimizer = torch.optim.Adam(self.__model.parameters(), lr=lr, weight_decay=l2_norm)

        # Decay LR by a factor of 0.1 every 15 epochs
        self.__exp_lr_scheduler = lr_scheduler.StepLR(self.__optimizer, step_size=15, gamma=0.1)

        print('Train size:', len(self.__dataset_train), 'Batch size:', batch_size)
        print('Batches per epoch:', len(self.__dataset_train) // batch_size)

        # Train the Model
        iteration_count = 0
        for epoch in range(num_epochs):
            for batch_idx, samples in enumerate(self.__train_loader):

                # Send inputs/labels to GPU
                images = samples['image'].to(self.__device)
                labels = samples['label'].to(self.__device)

                self.__optimizer.zero_grad()

                # Forward + Backward + Optimize
                outputs = self.__model(images)
                loss = self.__loss_func(outputs, labels.unsqueeze(dim=1))

                loss.backward()
                self.__optimizer.step()
                self.__exp_lr_scheduler.step(epoch)

                # Send loss to tensorboard
                self.__writer.add_scalar('loss/', loss.item(), iteration_count)
                self.__writer.add_histogram('steering_out', outputs.clone().detach().cpu().numpy(), iteration_count, bins='doane')
                self.__writer.add_histogram('steering_in', 
                                            labels.unsqueeze(dim=1).clone().detach().cpu().numpy(), iteration_count, bins='doane')

                # Get current learning rate (To display on Tensorboard)
                for param_group in self.__optimizer.param_groups:
                    curr_learning_rate = param_group['lr']
                    self.__writer.add_scalar('learning_rate/', curr_learning_rate, iteration_count)

                # Display on each epoch
                if batch_idx == 0:
                    # Send image to tensorboard
                    self.__writer.add_image('Image', images, epoch)
                    self.__writer.add_text('Steering', 'Steering:' + str(outputs[batch_idx].item()), epoch)
                    # Print Epoch and loss
                    print('Epoch [%d/%d] Loss: %.4f' % (epoch + 1, num_epochs, loss.item()))
                    # Save the Trained Model parameters
                    torch.save(self.__model.state_dict(), save_dir+'/cnn_' + str(epoch) + '.pkl')

                iteration_count += 1
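The scheduler pattern above, in isolation: StepLR rescales the 'lr' entry of every optimizer param group in place, which is why the current rate can be read back from optimizer.param_groups for logging.

import torch
from torch.optim import lr_scheduler

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

for epoch in range(30):
    # ... one epoch of training ...
    scheduler.step()
# lr falls to 1e-5 after 15 steps and 1e-6 after 30
print(optimizer.param_groups[0]['lr'])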
Exemplo n.º 19
0
            msk = deconv(feats)
            msk = functional.upsample(msk, scale_factor=4)
            prior = functional.sigmoid(msk)
            loss += criterion(msk, lbl)

        deconv.zero_grad()
        feature.zero_grad()

        loss.backward()

        optimizer_feature.step()
        optimizer_deconv.step()

        # visualize
        image = make_image_grid(inputs.data[:, :3], mean, std)
        writer.add_image('Image', torchvision.utils.make_grid(image), ib)
        msk = functional.sigmoid(msk)
        mask1 = msk.data

        mask1 = mask1.repeat(1, 3, 1, 1)
        writer.add_image('Image2', torchvision.utils.make_grid(mask1), ib)
        acc = math.e**(0 - loss)
        print('loss: %.4f,  acc %.4f, (epoch: %d, step: %d)' %
              (loss.data[0], acc, it, ib))
        writer.add_scalar('loss', loss.data[0], istep)
        writer.add_scalar('acc', acc.data[0], istep)
        istep += 1

        del inputs, msk, lbl, loss, feats, mask1, image, acc
        gc.collect()
        if ib % 24 == 0:
Exemplo n.º 20
0
class SummaryWorker(multiprocessing.Process):
    def __init__(self, env):
        super(SummaryWorker, self).__init__()
        self.env = env
        self.config = env.config
        self.queue = multiprocessing.Queue()
        try:
            self.timer_scalar = utils.train.Timer(env.config.getfloat('summary', 'scalar'))
        except configparser.NoOptionError:
            self.timer_scalar = lambda: False
        try:
            self.timer_image = utils.train.Timer(env.config.getfloat('summary', 'image'))
        except configparser.NoOptionError:
            self.timer_image = lambda: False
        try:
            self.timer_histogram = utils.train.Timer(env.config.getfloat('summary', 'histogram'))
        except configparser.NoOptionError:
            self.timer_histogram = lambda: False
        with open(os.path.expanduser(os.path.expandvars(env.config.get('summary_histogram', 'parameters'))), 'r') as f:
            self.histogram_parameters = utils.RegexList([line.rstrip() for line in f])
        self.draw_bbox = utils.visualize.DrawBBox(env.config, env.category)
        self.draw_iou = utils.visualize.DrawIou(env.config)

    def __call__(self, name, **kwargs):
        if getattr(self, 'timer_' + name)():
            kwargs = getattr(self, 'copy_' + name)(**kwargs)
            self.queue.put((name, kwargs))

    def stop(self):
        self.queue.put((None, {}))

    def run(self):
        self.writer = SummaryWriter(os.path.join(self.env.model_dir, self.env.args.run))
        while True:
            name, kwargs = self.queue.get()
            if name is None:
                break
            func = getattr(self, 'summary_' + name)
            try:
                func(**kwargs)
            except:
                traceback.print_exc()

    def copy_scalar(self, **kwargs):
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        loss_total = loss_total.data.clone().cpu().numpy()
        loss = {key: loss[key].data.clone().cpu().numpy() for key in loss}
        loss_hparam = {key: loss_hparam[key].data.clone().cpu().numpy() for key in loss_hparam}
        return dict(
            step=step,
            loss_total=loss_total,
            loss=loss, loss_hparam=loss_hparam,
        )

    def summary_scalar(self, **kwargs):
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        for key in loss:
            self.writer.add_scalar('loss/' + key, loss[key][0], step)
        if self.config.getboolean('summary_scalar', 'loss_hparam'):
            self.writer.add_scalars('loss_hparam', {key: loss_hparam[key][0] for key in loss_hparam}, step)
        self.writer.add_scalar('loss_total', loss_total[0], step)

    def copy_image(self, **kwargs):
        step, height, width, rows, cols, data, pred, debug = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, debug'.split(', '))
        data = {key: data[key].clone().cpu().numpy() for key in 'image, yx_min, yx_max, cls'.split(', ')}
        pred = {key: pred[key].data.clone().cpu().numpy() for key in 'yx_min, yx_max, iou, logits'.split(', ') if key in pred}
        matching = (debug['positive'].float() - debug['negative'].float() + 1) / 2
        matching = matching.data.clone().cpu().numpy()
        return dict(
            step=step, height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred,
            matching=matching,
        )

    def summary_image(self, **kwargs):
        step, height, width, rows, cols, data, pred, matching = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, matching'.split(', '))
        image = data['image']
        limit = min(self.config.getint('summary_image', 'limit'), image.shape[0])
        image = image[:limit, :, :, :]
        yx_min, yx_max, iou = (pred[key] for key in 'yx_min, yx_max, iou'.split(', '))
        scale = [height / rows, width / cols]
        yx_min, yx_max = (a * scale for a in (yx_min, yx_max))
        if 'logits' in pred:
            cls = np.argmax(F.softmax(torch.autograd.Variable(torch.from_numpy(pred['logits'])), -1).data.cpu().numpy(), -1)
        else:
            cls = np.zeros(iou.shape, np.int)
        if self.config.getboolean('summary_image', 'bbox'):
            # data
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/data')(self.draw_bbox_data)(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')))
            self.writer.add_image('bbox/data', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
            # pred
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/pred')(self.draw_bbox_pred)(canvas, yx_min, yx_max, cls, iou, nms=True)
            self.writer.add_image('bbox/pred', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
        if self.config.getboolean('summary_image', 'iou'):
            # bbox
            canvas = np.copy(image)
            canvas_data = self.draw_bbox_data(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')), colors=['g'])
            # data
            for i, canvas in enumerate(pybenchmark.profile('iou/data')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, matching, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/data%d' % i, canvas, step)
            # pred
            for i, canvas in enumerate(pybenchmark.profile('iou/pred')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, iou, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/pred%d' % i, canvas, step)

    def draw_bbox_data(self, canvas, yx_min, yx_max, cls, colors=None):
        batch_size = len(canvas)
        if len(cls.shape) == len(yx_min.shape):
            cls = np.argmax(cls, -1)
        yx_min, yx_max, cls = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_pred(self, canvas, yx_min, yx_max, cls, iou, colors=None, nms=False):
        batch_size = len(canvas)
        mask = iou > self.config.getfloat('detect', 'threshold')
        yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
        cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
        yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
        yx_min, yx_max, cls, iou = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls, iou))
        if nms:
            overlap = self.config.getfloat('detect', 'overlap')
            keep = [pybenchmark.profile('nms')(utils.postprocess.nms)(torch.Tensor(iou), torch.Tensor(yx_min), torch.Tensor(yx_max), overlap) if iou.shape[0] > 0 else [] for yx_min, yx_max, iou in zip(yx_min, yx_max, iou)]
            keep = [np.array(k, np.int) for k in keep]
            yx_min, yx_max, cls = ([a[k] for a, k in zip(l, keep)] for l in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_iou(self, canvas_share, yx_min, yx_max, cls, iou, rows, cols, colors=None):
        batch_size = len(canvas_share)
        yx_min, yx_max = ([np.squeeze(a, -2) for a in np.split(a, a.shape[-2], -2)] for a in (yx_min, yx_max))
        cls, iou = ([np.squeeze(a, -1) for a in np.split(a, a.shape[-1], -1)] for a in (cls, iou))
        results = []
        for i, (yx_min, yx_max, cls, iou) in enumerate(zip(yx_min, yx_max, cls, iou)):
            mask = iou > self.config.getfloat('detect', 'threshold')
            yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
            cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
            yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
            yx_min, yx_max, cls = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls))
            canvas = [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(np.copy(canvas_share), yx_min, yx_max, cls)]
            iou = [np.reshape(a, [rows, cols]) for a in iou]
            canvas = [self.draw_iou(_canvas, iou) for _canvas, iou in zip(canvas, iou)]
            results.append(canvas)
        return results

    def copy_histogram(self, **kwargs):
        return {key: kwargs[key].data.clone().cpu().numpy() if torch.is_tensor(kwargs[key]) else kwargs[key] for key in 'step, dnn'.split(', ')}


    def summary_histogram(self, **kwargs):
        step, dnn = (kwargs[key] for key in 'step, dnn'.split(', '))
        for name, param in dnn.named_parameters():
            if self.histogram_parameters(name):
                self.writer.add_histogram(name, param, step)
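Reduced to essentials, the pattern above detaches and copies tensor data on the training process, then queues (name, kwargs) tuples so a separate process does the slow serialization off the hot path. A stripped-down sketch of the same idea:

import multiprocessing

from tensorboardX import SummaryWriter

class AsyncScalarLogger(multiprocessing.Process):
    """Consumes (tag, value, step) tuples from a queue; None shuts it down."""

    def __init__(self, log_dir):
        super().__init__()
        self.log_dir = log_dir
        self.queue = multiprocessing.Queue()

    def log(self, tag, value, step):
        self.queue.put((tag, float(value), step))  # copy, never share tensors

    def stop(self):
        self.queue.put(None)

    def run(self):
        writer = SummaryWriter(self.log_dir)  # created inside the child process
        while True:
            item = self.queue.get()
            if item is None:
                break
            writer.add_scalar(*item)
        writer.close()

# usage: logger = AsyncScalarLogger('runs/async'); logger.start()
#        logger.log('loss', 0.5, step) ... logger.stop(); logger.join()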
Exemplo n.º 21
0
class SummaryWorker(threading.Thread):
    def __init__(self, env):
        super(SummaryWorker, self).__init__()
        self.config = env.config
        self.running = True
        self.queue = queue.Queue()
        self.writer = SummaryWriter(os.path.join(env.model_dir, env.args.run))
        try:
            self.timer_scalar = utils.train.Timer(env.config.getfloat('summary_secs', 'scalar'))
        except configparser.NoOptionError:
            self.timer_scalar = lambda: False
        try:
            self.timer_image = utils.train.Timer(env.config.getfloat('summary_secs', 'image'))
        except configparser.NoOptionError:
            self.timer_image = lambda: False
        try:
            self.timer_histogram = utils.train.Timer(env.config.getfloat('summary_secs', 'histogram'))
        except configparser.NoOptionError:
            self.timer_histogram = lambda: False
        with open(os.path.expanduser(os.path.expandvars(env.config.get('summary_histogram', 'parameters'))), 'r') as f:
            self.histogram_parameters = utils.RegexList([line.rstrip() for line in f])
        self.draw_bbox = utils.visualize.DrawBBox(env.config, env.category)
        self.draw_iou = utils.visualize.DrawIou(env.config)

    def __call__(self, name, **kwargs):
        if getattr(self, 'timer_' + name)():
            func = getattr(self, 'summary_' + name)
            kwargs = getattr(self, 'copy_' + name)(**kwargs)
            self.queue.put((func, kwargs))

    def stop(self):
        if self.running:
            self.running = False
            self.queue.put((lambda **kwargs: None, {}))

    def run(self):
        while self.running:
            func, kwargs = self.queue.get()
            try:
                func(**kwargs)
            except:
                traceback.print_exc()

    def copy_scalar(self, **kwargs):
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        loss_total = loss_total.data.clone().cpu().numpy()
        loss = {key: loss[key].data.clone().cpu().numpy() for key in loss}
        loss_hparam = {key: loss_hparam[key].data.clone().cpu().numpy() for key in loss_hparam}
        return dict(
            step=step,
            loss_total=loss_total,
            loss=loss, loss_hparam=loss_hparam,
        )

    def summary_scalar(self, **kwargs):
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        for key in loss:
            self.writer.add_scalar('loss/' + key, loss[key][0], step)
        if self.config.getboolean('summary_scalar', 'loss_hparam'):
            self.writer.add_scalars('loss_hparam', {key: loss_hparam[key][0] for key in loss_hparam}, step)
        self.writer.add_scalar('loss_total', loss_total[0], step)

    def copy_image(self, **kwargs):
        step, height, width, rows, cols, data, pred, debug = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, debug'.split(', '))
        data = {key: data[key].clone().cpu().numpy() for key in 'image, yx_min, yx_max, cls'.split(', ')}
        pred = {key: pred[key].data.clone().cpu().numpy() for key in 'yx_min, yx_max, iou, logits'.split(', ') if key in pred}
        matching = (debug['positive'].float() - debug['negative'].float() + 1) / 2
        matching = matching.data.clone().cpu().numpy()
        return dict(
            step=step, height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred,
            matching=matching,
        )

    def summary_image(self, **kwargs):
        step, height, width, rows, cols, data, pred, matching = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, matching'.split(', '))
        image = data['image']
        limit = min(self.config.getint('summary_image', 'limit'), image.shape[0])
        image = image[:limit, :, :, :]
        yx_min, yx_max, iou = (pred[key] for key in 'yx_min, yx_max, iou'.split(', '))
        scale = [height / rows, width / cols]
        yx_min, yx_max = (a * scale for a in (yx_min, yx_max))
        if 'logits' in pred:
            cls = np.argmax(F.softmax(torch.autograd.Variable(torch.from_numpy(pred['logits'])), -1).data.cpu().numpy(), -1)
        else:
            cls = np.zeros(iou.shape, np.int)
        if self.config.getboolean('summary_image', 'bbox'):
            # data
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/data')(self.draw_bbox_data)(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')))
            self.writer.add_image('bbox/data', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
            # pred
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/pred')(self.draw_bbox_pred)(canvas, yx_min, yx_max, cls, iou, nms=True)
            self.writer.add_image('bbox/pred', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
        if self.config.getboolean('summary_image', 'iou'):
            # bbox
            canvas = np.copy(image)
            canvas_data = self.draw_bbox_data(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')), colors=['g'])
            # data
            for i, canvas in enumerate(pybenchmark.profile('iou/data')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, matching, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/data%d' % i, canvas, step)
            # pred
            for i, canvas in enumerate(pybenchmark.profile('iou/pred')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, iou, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/pred%d' % i, canvas, step)

    def draw_bbox_data(self, canvas, yx_min, yx_max, cls, colors=None):
        batch_size = len(canvas)
        if len(cls.shape) == len(yx_min.shape):
            cls = np.argmax(cls, -1)
        yx_min, yx_max, cls = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_pred(self, canvas, yx_min, yx_max, cls, iou, colors=None, nms=False):
        batch_size = len(canvas)
        mask = iou > self.config.getfloat('detect', 'threshold')
        yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
        cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
        yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
        yx_min, yx_max, cls, iou = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls, iou))
        if nms:
            overlap = self.config.getfloat('detect', 'overlap')
            keep = [pybenchmark.profile('nms')(utils.postprocess.nms)(torch.Tensor(yx_min), torch.Tensor(yx_max), torch.Tensor(iou), overlap) if iou.shape[0] > 0 else [] for yx_min, yx_max, iou in zip(yx_min, yx_max, iou)]
            keep = [np.array(k, np.int) for k in keep]
            yx_min, yx_max, cls = ([a[k] for a, k in zip(l, keep)] for l in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_iou(self, canvas_share, yx_min, yx_max, cls, iou, rows, cols, colors=None):
        batch_size = len(canvas_share)
        yx_min, yx_max = ([np.squeeze(a, -2) for a in np.split(a, a.shape[-2], -2)] for a in (yx_min, yx_max))
        cls, iou = ([np.squeeze(a, -1) for a in np.split(a, a.shape[-1], -1)] for a in (cls, iou))
        results = []
        for i, (yx_min, yx_max, cls, iou) in enumerate(zip(yx_min, yx_max, cls, iou)):
            mask = iou > self.config.getfloat('detect', 'threshold')
            yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
            cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
            yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
            yx_min, yx_max, cls = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls))
            canvas = [self.draw_bbox(canvas, yx_min.astype(np.int), yx_max.astype(np.int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(np.copy(canvas_share), yx_min, yx_max, cls)]
            iou = [np.reshape(a, [rows, cols]) for a in iou]
            canvas = [self.draw_iou(_canvas, iou) for _canvas, iou in zip(canvas, iou)]
            results.append(canvas)
        return results

    def copy_histogram(self, **kwargs):
        return {key: kwargs[key].data.clone().cpu().numpy() if torch.is_tensor(kwargs[key]) else kwargs[key] for key in 'step, dnn'.split(', ')}


    def summary_histogram(self, **kwargs):
        step, dnn = (kwargs[key] for key in 'step, dnn'.split(', '))
        for name, param in dnn.named_parameters():
            if self.histogram_parameters(name):
                self.writer.add_histogram(name, param, step)
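
The draw_bbox_pred method above keeps only boxes whose predicted IoU clears the configured detection threshold, and can optionally apply non-maximum suppression before drawing. Below is a minimal single-image sketch of that filtering step; the array shapes and the utils.postprocess.nms signature are inferred from how the method calls them, so treat it as an illustration rather than part of the original source.

import numpy as np
import torch
import utils.postprocess

def filter_boxes(yx_min, yx_max, cls, iou, threshold, overlap):
    # keep boxes whose predicted IoU exceeds the detection threshold
    mask = iou > threshold
    yx_min, yx_max, cls, iou = (a[mask] for a in (yx_min, yx_max, cls, iou))
    if iou.shape[0] > 0:
        # suppress overlapping boxes, keeping the higher-scoring ones
        keep = utils.postprocess.nms(torch.Tensor(yx_min), torch.Tensor(yx_max),
                                     torch.Tensor(iou), overlap)
        keep = np.array(keep, int)
        yx_min, yx_max, cls = (a[keep] for a in (yx_min, yx_max, cls))
    return yx_min, yx_max, cls
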
Example No. 22
0
class Stats:

    def __init__(self, log_dir=None):
        self.test_noise = None
        self.writer = SummaryWriter(log_dir=log_dir)
        self.input_shape = None
        self.last_notification = None
        os.makedirs("%s/images" % self.writer.file_writer.get_logdir())
        shutil.copy("./evolution/config.py", self.writer.file_writer.get_logdir())  # copy setup file into log dir

    def save_data(self, epoch, g_pop, d_pop, save_best_model=False):
        if not os.path.isfile(os.path.join(self.writer.file_writer.get_logdir(), "./generator_noise.pt")):
            shutil.copy("./generator_noise.pt", self.writer.file_writer.get_logdir())  # copy noise file into log dir

        epoch_dir = f"{self.writer.file_writer.get_logdir()}/generations/{epoch:03d}"
        os.makedirs(epoch_dir)
        global_data_values = {}
        for name, pop in [("d", d_pop), ("g", g_pop)]:
            phenotypes = pop.phenotypes()
            global_data_values[f"species_{name}"] = len(pop.species_list)
            global_data_values[f"speciation_threshold_{name}"] = pop.speciation_threshold
            global_data_values[f"invalid_{name}"] = sum([p.invalid for p in phenotypes])
            # generate data for current generation
            columns = ["loss", "trained_samples", "layers", "genes_used",
                       "model", "species_index", "fitness", "generation", "age"]
            if name == "g":
                columns.append("fid_score")
                columns.append("inception_score")
                columns.append("rmse_score")
            df = pd.DataFrame(index=np.arange(0, len(phenotypes)), columns=columns)
            j = 0
            for i, species in enumerate(pop.species_list):
                for p in species:
                    values = [p.error, p.trained_samples, len(p.genome.genes), np.mean([g.used for g in p.genome.genes]),
                              p.to_json(), i, p.fitness(), p.genome.generation, p.genome.age]
                    if name == "g":
                        values.append(p.fid_score)
                        values.append(p.inception_score_mean)
                        values.append(p.rmse_score)
                    df.loc[j] = values
                    j += 1
            df.sort_values('fitness').reset_index(drop=True).to_csv(f"{epoch_dir}/data_{name}.csv")

        # generate image for each G
        os.makedirs(f"{epoch_dir}/images")
        for i, g in enumerate(g_pop.sorted()):
            if not g.valid():
                continue
            self.generate_image(g, path=f"{epoch_dir}/images/generator-{i:03d}.png")
            if i == 0 and save_best_model:
                g.save(f"{epoch_dir}/generator.pkl")

        if save_best_model:
            d_pop.sorted()[0].save(f"{epoch_dir}/discriminator.pkl")

        # append values into global data
        global_data = pd.DataFrame(data=global_data_values, index=[epoch])
        with open(f"{self.writer.file_writer.get_logdir()}/generations/data.csv", 'a') as f:
            global_data.to_csv(f, header=epoch == 0)

    def generate_image(self, G, path=None):
        if not G.valid():
            return None
        test_images = G(self.test_noise).detach()
        grid_images = [torch.from_numpy((test_images[k, :].data.cpu().numpy().reshape(self.input_shape) + 1)/2)
                        for k in range(config.stats.num_generated_samples)]
        grid = vutils.make_grid(grid_images, normalize=False, nrow=int(config.stats.num_generated_samples**(1/2)))
        # store grid images in the run folder
        if path is not None:
            vutils.save_image(grid, path)
        return grid

    def generate(self, input_shape, g_pop, d_pop, epoch, num_epochs, train_loader, validation_loader):
        if epoch % config.stats.print_interval != 0 and epoch != num_epochs - 1:
            return

        generators = g_pop.sorted()
        discriminators = d_pop.sorted()
        G = g_pop.best()
        D = d_pop.best()
        G.eval()
        D.eval()

        # this should never occur!
        if G.invalid or D.invalid:
            logger.error("invalid D or G")
            return

        self.input_shape = input_shape
        if self.test_noise is None:
            self.test_noise = G.generate_noise(config.stats.num_generated_samples, volatile=True).cpu()
            # display noise only once
            # grid_noise = vutils.make_grid(self.test_noise.data, normalize=True, scale_each=True, nrow=4)
            # self.writer.add_image('Image/Noise', grid_noise)

        if config.stats.calc_rmse_score:
            rmse_score.initialize(train_loader, config.evolution.fitness.fid_sample_size)
            for g in generators:
                g.calc_rmse_score()

        if config.stats.calc_inception_score:
            for g in generators:
                g.inception_score()
            self.writer.add_scalars('Training/Inception_score', {"Best_G": G.inception_score_mean}, epoch)
        if G.fid_score is not None:
            self.writer.add_scalars('Training/Fid_score', {"Best_G": G.fid_score}, epoch)

        self.save_data(epoch, g_pop, d_pop, config.stats.save_best_model and (epoch == num_epochs-1 or epoch % config.stats.save_best_interval == 0))

        self.writer.add_scalars('Training/Trained_samples', {"Best_D": D.trained_samples,
                                                             "Best_G": G.trained_samples,
                                                             "D": sum([p.trained_samples for p in discriminators])/len(discriminators),
                                                             "G": sum([p.trained_samples for p in generators])/len(generators)
                                                             }, epoch)
        self.writer.add_scalars('Training/Loss', {"Best_D": D.error, "Best_G": G.error}, epoch)
        self.writer.add_scalars('Training/Fitness', {"Best_D": D.fitness(), "Best_G": G.fitness()}, epoch)
        self.writer.add_scalars('Training/Generation', {"Best_D": D.genome.generation, "Best_G": G.genome.generation}, epoch)
        self.writer.add_histogram('Training/Loss/D', np.array([p.error for p in discriminators]), epoch)
        self.writer.add_histogram('Training/Loss/G', np.array([p.error for p in generators]), epoch)
        self.writer.add_histogram('Training/Trained_samples/D', np.array([p.trained_samples for p in discriminators]), epoch)
        self.writer.add_histogram('Training/Trained_samples/G', np.array([p.trained_samples for p in generators]), epoch)

        # generate images with the best performing G's
        for i, gen in enumerate(generators[:config.stats.print_best_amount]):
            image_path = None
            if i == 0:
                image_path = '%s/images/generated-%05d.png' % (self.writer.file_writer.get_logdir(), epoch)
            grid = self.generate_image(gen, path=image_path)
            self.writer.add_image('Image/Best_G/%d' % i, grid, epoch)

        # write architectures for best G and D
        self.writer.add_text('Graph/Best_G', str([str(p) for p in generators[:config.stats.print_best_amount]]), epoch)
        self.writer.add_text('Graph/Best_D', str([str(p) for p in discriminators[:config.stats.print_best_amount]]), epoch)

        # apply best G and D in the validator dataset
        # FIXME: the validation dataset was already evaluated at this point. Just reuse the data.
        if config.stats.display_validation_stats:
            d_errors_real, d_errors_fake, g_errors = [], [], []
            for n, (images, _) in enumerate(validation_loader):
                images = tools.cuda(Variable(images))
                batch_size = images.size(0)

                d_errors_real.append(D.step_real(images))
                fake_error, _ = D.step_fake(G, batch_size)
                d_errors_fake.append(fake_error)
                g_errors.append(G.step(D, batch_size).data[0])

            # display validation metrics
            self.writer.add_scalars('Validation/D/Loss', {'Real': np.mean(d_errors_real),
                                                          'Fake': np.mean(d_errors_fake)}, epoch)
            self.writer.add_scalars('Validation/Loss', {'Best_D': np.mean(d_errors_real + d_errors_fake),
                                                        'Best_G': np.mean(g_errors)}, epoch)

        # display architecture metrics
        self.writer.add_scalars('Architecture/Layers', {'Best_D': len(D.genome.genes),
                                                        'Best_G': len(G.genome.genes),
                                                        'D': np.mean([len(p.genome.genes) for p in discriminators]),
                                                        'G': np.mean([len(p.genome.genes) for p in generators])
                                                        }, epoch)
        self.writer.add_histogram('Architecture/Layers/D', np.array([len(p.genome.genes) for p in discriminators]), epoch)
        self.writer.add_histogram('Architecture/Layers/G', np.array([len(p.genome.genes) for p in generators]), epoch)

        self.writer.add_scalars('Architecture/Invalid', {'D': sum([p.invalid for p in discriminators]),
                                                         'G': sum([p.invalid for p in generators])
                                                         }, epoch)

        self.writer.add_scalars('Architecture/Species', {"D": len(d_pop.species_list), "G": len(g_pop.species_list)}, epoch)
        self.writer.add_scalars('Architecture/Speciation_Threshold', {"D": int(d_pop.speciation_threshold),
                                                                      "G": int(g_pop.speciation_threshold)}, epoch)

        best_d_used = np.mean([g.used for g in D.genome.genes])
        best_g_used = np.mean([g.used for g in G.genome.genes])
        d_used = np.mean([np.mean([g.used for g in p.genome.genes]) for p in discriminators])
        g_used = np.mean([np.mean([g.used for g in p.genome.genes]) for p in generators])
        self.writer.add_scalars('Architecture/Genes_reuse', {'Best_D': best_d_used, 'Best_G': best_g_used,
                                                             'D': d_used, 'G': g_used}, epoch)

        logger.debug("\n%s: D: %s G: %s", epoch, D.error, G.error)
        logger.debug(G)
        logger.debug(G.model)
        logger.debug(D)
        logger.debug(D.model)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug((f"memory_allocated: {torch.cuda.memory_allocated()}, "
                          f"max_memory_allocated: {torch.cuda.max_memory_allocated()}, "
                          f"memory_cached: {torch.cuda.memory_cached()}, "
                          f"max_memory_cached: {torch.cuda.max_memory_cached()}"))

        if config.stats.notify and \
                (self.last_notification is None or
                 (datetime.now() - self.last_notification).seconds//60 > config.stats.min_notification_interval):
            self.last_notification = datetime.now()
            notify(f"Epoch {epoch}: G {G.fitness():.2f}, D: {D.error:.2f}")

        # graph plotting
        # dummy_input = Variable(torch.randn(28, 28)).cuda()
        # self.writer.add_graph(D.model, (dummy_input, ))
        # dummy_input = Variable(torch.randn(10, 10)).cuda()
        # self.writer.add_graph(G.model, (dummy_input, ))

        # flush writer to avoid memory issues
        self.writer.scalar_dict = {}
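
Stats.generate above logs most of its metrics through SummaryWriter.add_scalars, which draws several named series in a single chart. A minimal sketch of the pattern (the run directory and tag names are illustrative only):

from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/demo')
for epoch in range(3):
    # one chart ('Training/Loss') holding two curves ('Best_D' and 'Best_G')
    writer.add_scalars('Training/Loss',
                       {'Best_D': 0.5 / (epoch + 1), 'Best_G': 1.0 / (epoch + 1)},
                       epoch)
writer.close()
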
Example No. 23
0
        
        running_scalars['Fake Discriminator Loss 2'] += fake_discriminator_loss2.data[0]
        
        running_scalars['Fake Discriminator Loss 3'] += fake_discriminator_loss3.data[0]
        
        running_scalars['Generator Loss'] += generator_loss.data[0]


######################################## TensorboardX Comparison of images #################################
        running_example_count += images.size()[0]
        if step % settings.summary_step_period == 0 and step != 0:
            comparison_image1 = viewer.create_crowd_images_comparison_grid(cpu(images), cpu(labels), cpu(predicted_labels1))
            comparison_image2 = viewer.create_crowd_images_comparison_grid(cpu(images), cpu(labels), cpu(predicted_labels2))
            comparison_image3 = viewer.create_crowd_images_comparison_grid(cpu(images), cpu(labels), cpu(predicted_labels3))

            summary_writer.add_image('Comparison1', comparison_image1, global_step=step)
            summary_writer.add_image('Comparison2', comparison_image2, global_step=step)
            summary_writer.add_image('Comparison3', comparison_image3, global_step=step)
            
            fake_images_image = torchvision.utils.make_grid(fake_images.data[:9], nrow=3)
            summary_writer.add_image('Fake', fake_images_image, global_step=step)

##################################### TensorboardX Scalars #########################################
            mean_loss = running_scalars['Loss'] / running_example_count
            print('[Epoch: {}, Step: {}] Loss: {:g}'.format(epoch, step, mean_loss))

            for name, running_scalar in running_scalars.items():
                mean_scalar = running_scalar / running_example_count
                summary_writer.add_scalar(name, mean_scalar, global_step=step)
                running_scalars[name] = 0
            running_example_count = 0
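
The fragment above follows a common accumulate-and-flush pattern: per-batch scalars are summed into running_scalars, and every summary_step_period steps their means are written to TensorBoard and the accumulators reset. A compact sketch of the same pattern, with hypothetical names:

from collections import defaultdict

running_scalars = defaultdict(float)
running_example_count = 0

def flush_scalars(summary_writer, step):
    # write the mean of each accumulated scalar, then reset the accumulators
    global running_example_count
    for name in list(running_scalars):
        summary_writer.add_scalar(name, running_scalars[name] / running_example_count,
                                  global_step=step)
        running_scalars[name] = 0
    running_example_count = 0
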
Example No. 24
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='GPU to use (0 based)', required=False, default=0, type=int)
    parser.add_argument('--db', help='Database to use during training', required=True, type=str)
    parser.add_argument('--transform_pre', help='Apply transformation before casting to Tensor', type=str)
    parser.add_argument('--model', help='Network model to be trained', required=True, type=str)
    parser.add_argument('--subsample', help='Fraction of image to keep subsampling db', type=float)
    parser.add_argument('--patch_size', help='Patch size', type=int)
    parser.add_argument('--patch_stride', help='Patch stride', type=int)
    parser.add_argument('--num_workers', help='Parallel job in data loading', required=False, type=int)
    parser.add_argument('--batch_size', help='Training image batch size', required=False, type=int)
    parser.add_argument('--lr', help='Learning rate', required=False, type=float)
    parser.add_argument('--n_epochs', help='Number of training epochs', required=False, type=int)
    parser.add_argument('--continue_train', help='resume training from best epoch of #### run', type=str)
    parser.add_argument('--debug', help='Debug flag for visualization', required=False, action='store_true')

    args = parser.parse_args()

    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    db_name = args.db
    transform_pre = args.transform_pre
    model_name = args.model
    subsample = args.subsample if args.subsample is not None else default_subsample
    patch_size = args.patch_size if args.patch_size is not None else default_patch_size
    patch_stride = args.patch_stride if args.patch_stride is not None else default_patch_stride
    num_workers = args.num_workers if args.num_workers is not None else default_num_workers
    batch_size = args.batch_size if args.batch_size is not None else default_batch_size
    lr = args.lr if args.lr is not None else default_lr
    n_epochs = args.n_epochs if args.n_epochs is not None else default_n_epochs
    continue_train_run = args.continue_train
    debug = args.debug

    # transform function as needed from torch model zoo
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Initialize database and dataloader
    db_class = getattr(db_classes, db_name)
    db = db_class(patch_size=patch_size, patch_stride=patch_stride, transform_pre=transform_pre, transform_post=normalize, subsample=subsample)
    db.generate_split(train_size=default_train_size)
    dl_train = DataLoader(db, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=True)

    # initialize network
    print('Loading network weights')
    model = None
    if model_name == 'resnet':
        model = models.resnet18(pretrained=False)
        model.load_state_dict(model_zoo.load_url(resnet_model_urls['resnet18'], model_dir=model_path))

        # craft network last layer to be a two-class classifier
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, 1)

        # model = ResNet(BasicBlock, [2, 2, 2])
        # model.load_state_dict(model_zoo.load_url(resnet_model_urls['resnet18'], model_dir=model_path), strict=False)

    elif model_name == 'alexnet':
        model = models.alexnet(pretrained=False)
        model.load_state_dict(model_zoo.load_url(alexnet_model_urls['alexnet'], model_dir=model_path))

        # craft network last layer to be a two-class classifier
        num_ftrs = model.classifier._modules['6'].in_features
        model.classifier._modules['6'] = torch.nn.Linear(num_ftrs, 1)

    elif model_name == 'vgg':
        # model = models.vgg16_bn(pretrained=False, init_weights=False)
        # model.load_state_dict(model_zoo.load_url(vgg_model_urls['vgg16'], model_dir=model_path))
        model = models.vgg16_bn(pretrained=True)

        # craft network last layer to be a two-class classifier
        num_ftrs = model.classifier._modules['6'].in_features
        model.classifier._modules['6'] = torch.nn.Linear(num_ftrs, 1)

    # Load weights
    if continue_train_run is not None:
        run_folder = glob(os.path.join(runs_path, '*-{}'.format(continue_train_run)))[0]
        try:
            model.load_state_dict(torch.load(os.path.join(run_folder, 'model_best.pth')))
        except FileNotFoundError:
            print('No models weights found in {}.\nTraining from scratch...'.format(run_folder))

    # craft network last layer to be a two-class classifier
    # num_ftrs = model.fc.in_features
    # model.fc = torch.nn.Linear(num_ftrs, np.prod(patch_size[:2]))
    model.to(device)

    # define criterion
    criterion = torch.nn.BCELoss()
    s = torch.nn.Sigmoid()

    # define optimizer and lr decay
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

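    # NOTE: writer_dir and run_name are only defined inside the debug branch
    # below, yet they are also used further down when saving the best model and
    # the results CSV, so a run without --debug will fail at those points.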
    if debug:
        # Prepare Tensorboard writer
        writer_dir = os.path.join(runs_path,
                                  model.__class__.__name__ + '_' +
                                  '{}'.format(db.__class__.__name__) + '_' +
                                  'transform_{}'.format(transform_pre) + '_' +
                                  'subsample_{}'.format(subsample) + '_' +
                                  'patch_{}'.format(patch_size) + '_' +
                                  'stride_{}'.format(patch_stride) + '_' +
                                  'batch_{}'.format(batch_size) + '_' +
                                  'lr_{}'.format(lr) + '_' +
                                  'nepochs_{}'.format(n_epochs))
        run_name = [random.choice(string.ascii_letters + string.digits) for _ in range(6)]
        writer_dir += '-' + ''.join(run_name)
        while os.path.exists(writer_dir):
            run_name = [random.choice(string.ascii_letters + string.digits) for _ in range(6)]
            writer_dir = writer_dir.rsplit('-', 1)[0]
            writer_dir += '-' + ''.join(run_name)
        writer = SummaryWriter(writer_dir)
        db.train_db.to_csv(os.path.join(writer_dir, 'train.csv'))
        db.val_db.to_csv(os.path.join(writer_dir, 'val.csv'))
        print('\n\n')
        print('Finetuning model {} on db {}, run {}'.format(model.__class__.__name__, db.__class__.__name__,
                                                                    ''.join(run_name)))
        if transform_pre:
            print('\nApply preprocessing {}'.format(transform_pre))

        print('\n\n')

    min_val_loss = np.inf
    early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(n_epochs):
        if not early_stop_flag:
            # -----------------
            #  Train
            # -----------------
            db.train()
            lr_scheduler.step()
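            # note: stepping the scheduler at the start of the epoch follows the
            # pre-1.1 PyTorch convention; since PyTorch 1.1, scheduler.step() is
            # expected after the epoch's optimizer.step() calls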
            model.train()
            running_loss_train = 0
            for i_batch, sample_batched in tqdm(enumerate(dl_train), desc='Train Epoch {}'.format(epoch + 1),
                                                total=len(dl_train), unit='batch'):
                overall_iter = int(i_batch + epoch * len(dl_train))

                # load data
                X = sample_batched[0].to(device)
                y = sample_batched[1].to(device)

                # zero the gradients
                optimizer.zero_grad()

                # forward
                y_hat = s(model(X))
                loss = criterion(y_hat, y.view(batch_size, -1))

                # backward
                loss.backward()
                optimizer.step()

                # statistics
                running_loss_train += loss.item()

                if debug and overall_iter % log_period_iter == log_start_iter:
                    writer.add_scalar('loss/train', loss.item(), overall_iter)

            epoch_loss_train = running_loss_train / len(db)
            print('Train Loss: {:.4f}'.format(epoch_loss_train))
            # ---------------------
            #  Validation
            # ---------------------
            db.val()
            model.eval()
            dl_val = DataLoader(db, batch_size=batch_size, num_workers=num_workers, shuffle=True)
            running_loss_val = 0
            for i_batch, sample_batched in tqdm(enumerate(dl_val), desc='Val Epoch {}'.format(epoch + 1),
                                                total=len(dl_val), unit='batch'):

                # load data
                X = sample_batched[0].to(device)
                y = sample_batched[1].to(device)

                # forward
                y_hat = s(model(X))
                loss = criterion(y_hat, y.view(batch_size, -1))

                # statistics
                running_loss_val += loss.item()

            epoch_loss_val = running_loss_val / len(db)
            print('Val Loss: {:.4f}'.format(epoch_loss_val))

            if debug:
                writer.add_scalar('loss/val', epoch_loss_val, overall_iter)
                writer.add_image('X', X[0].detach(), overall_iter)
                # writer.add_image('y', y[0].detach().view(patch_size[0], patch_size[1]))
                # writer.add_image('y_hat', y_hat[0].detach().view(patch_size[0], patch_size[1]))

            if min_val_loss - epoch_loss_val > 1e-4:
                if min_val_loss == np.inf:
                    print('Val_loss {}. \nSaving models...'.format(epoch_loss_val))
                else:
                    print('Val_loss improved by {0:.6f}. \nSaving models...'.format(min_val_loss - epoch_loss_val))
                torch.save(model.state_dict(), os.path.join(writer_dir, 'model_best.pth'))
                min_val_loss = epoch_loss_val

                # save results to csv
                train_results = pd.DataFrame({'model': model.__class__.__name__,
                                              'subsample': subsample,
                                              'patch': patch_size[0],
                                              'stride': patch_stride[0],
                                              'batch': batch_size,
                                              'lr': lr,
                                              'val_loss': epoch_loss_val,
                                              'epoch': epoch + 1,
                                              'run_name': ''.join(run_name)}, index=[0])
                train_results.to_csv(os.path.join(writer_dir, 'train_results.csv'))
                early_stop_counter = 0

            else:
                early_stop_counter += 1

            if debug:
                writer.add_scalar('Epoch', epoch + 1, overall_iter)

            if early_stop_counter == early_stop:
                early_stop_flag = True
                print('\nEarly stopping due to non-decreasing validation loss for {} epochs\n'.format(early_stop))

    return 0
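
main() stops training once the validation loss has failed to improve by more than 1e-4 for early_stop consecutive epochs. The same logic in isolation, with an illustrative loss sequence and patience value:

import numpy as np

min_val_loss, patience, counter = np.inf, 3, 0
for epoch_loss_val in [0.90, 0.70, 0.71, 0.72, 0.73, 0.74]:
    if min_val_loss - epoch_loss_val > 1e-4:  # meaningful improvement
        min_val_loss = epoch_loss_val
        counter = 0                           # reset patience
    else:
        counter += 1
    if counter == patience:
        print('early stopping')               # triggered after three flat epochs
        break
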
Example No. 25
0
def train(args, logger, device_ids):
    writer = SummaryWriter()

    logger.info("Loading network")
    #model = AdaMatting(in_channel=4)
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.encoder)

    batch_norm = 'BN' in args.encoder
    net_decoder = builder.build_decoder(arch=args.decoder,
                                        batch_norm=batch_norm)

    model = MattingModule(net_encoder, net_decoder)

    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            # model = convert_model(model)
            model = model.to(device)
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        else:
            model = model.to(device)
    else:
        device = torch.device("cpu")
        model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True,
                                               drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               pin_memory=True,
                                               drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

    max_iter = 43100 * (
        1 - args.valid_portion / 100) / args.batch_size * args.epochs
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            # cur_lr, peak_lr = lr_scheduler(optimizer=optimizer, cur_iter=cur_iter, peak_lr=peak_lr, end_lr=0.000001,
            #                                decay_iters=args.decay_iters, decay_power=0.8, power=0.5)
            cur_lr = lr_scheduler(optimizer=optimizer,
                                  init_lr=args.lr,
                                  cur_iter=cur_iter,
                                  max_iter=max_iter,
                                  max_decay_times=40,
                                  decay_rate=0.9)

            # img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
            inputs = inputs.to(device)
            gt_alpha = (gts[:,
                            0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(
                                device)  # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(
                device)  # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
                inputs)

            L_overall, L_t, L_a = task_uncertainty_loss(
                pred_trimap=trimap_adaption,
                input_trimap_argmax=inputs[:, 3, :, :],
                pred_alpha=alpha_estimation,
                gt_trimap=gt_trimap,
                gt_alpha=gt_alpha,
                log_sigma_t_sqr=log_sigma_t_sqr,
                log_sigma_a_sqr=log_sigma_a_sqr)

            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() /
                                         2), torch.exp(log_sigma_a_sqr.mean() /
                                                       2)
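            # exp(log(sigma^2) / 2) equals sigma, i.e. these recover the task
            # uncertainty standard deviations from the predicted log-variances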

            optimizer.zero_grad()
            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info(
                    "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                    .format(epoch, index, len(train_loader), avg_lo.avg,
                            avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg,
                                  tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(),
                                  tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(),
                                  tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)

                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device)  # [bs, 4, 320, 320]
                gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(
                    torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(
                    device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(
                    inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(
                    pred_trimap=trimap_adaption,
                    input_trimap_argmax=inputs[:, 3, :, :],
                    pred_alpha=alpha_estimation,
                    gt_trimap=gt_trimap,
                    gt_alpha=gt_alpha,
                    log_sigma_t_sqr=log_sigma_t_sqr,
                    log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    input_rgb = torchvision.utils.make_grid(display_rgb,
                                                            normalize=False,
                                                            scale_each=True)
                    writer.add_image('input/rgb_image', input_rgb,
                                     tensorboard_iter)

                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap,
                                                               normalize=False,
                                                               scale_each=True)
                    writer.add_image('input/trimap', input_trimap,
                                     tensorboard_iter)

                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha,
                                                               normalize=False,
                                                               scale_each=True)
                    writer.add_image('output/alpha', output_alpha,
                                     tensorboard_iter)

                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) /
                                           2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(
                        trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('pred/trimap_adaptation',
                                     trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(
                        alpha_estimation, normalize=False, scale_each=True)
                    writer.add_image('pred/alpha_estimation',
                                     alpha_estimation_res, tensorboard_iter)

                    gt_alpha = torchvision.utils.make_grid(gt_alpha,
                                                           normalize=False,
                                                           scale_each=True)
                    writer.add_image('gt/alpha', gt_alpha, tensorboard_iter)

                    gt_trimap = (gt_trimap.type(torch.FloatTensor) /
                                 2).unsqueeze(dim=1)
                    gt_trimap = torchvision.utils.make_grid(gt_trimap,
                                                            normalize=False,
                                                            scale_each=True)
                    writer.add_image('gt/trimap', gt_trimap, tensorboard_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(
            avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(
            avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg,
                          tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.raw_data_path,
                            is_best=is_best,
                            is_alpha_best=is_alpha_best,
                            logger=logger,
                            model=model,
                            optimizer=optimizer,
                            epoch=epoch,
                            cur_iter=cur_iter,
                            peak_lr=peak_lr,
                            best_loss=best_loss,
                            best_alpha_loss=best_alpha_loss)

    writer.close()
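
The custom lr_scheduler function called in the training loop above is not shown in this example. A minimal sketch consistent with its call signature, assuming a step decay of decay_rate applied every max_iter / max_decay_times iterations (the real schedule may differ):

def lr_scheduler(optimizer, init_lr, cur_iter, max_iter, max_decay_times, decay_rate):
    # step decay: multiply init_lr by decay_rate once per decay interval
    decay_interval = max(1, int(max_iter / max_decay_times))
    cur_lr = init_lr * (decay_rate ** (cur_iter // decay_interval))
    for group in optimizer.param_groups:
        group['lr'] = cur_lr
    return cur_lr
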
Example No. 26
0
class CombinedLogger(object):
    """Combine console and tensorboard logger and record system metrics.
    """
    def __init__(self,
                 name,
                 log_dir,
                 server_env=True,
                 fold="",
                 sysmetrics_interval=2):
        self.pylogger = logging.getLogger(name)
        self.tboard = SummaryWriter(log_dir=log_dir)
        self.times = {}
        self.fold = fold
        # monitor system metrics (cpu, mem, ...)
        if not server_env:
            self.sysmetrics = pd.DataFrame(columns=[
                "global_step", "rel_time", r"CPU (%)", "mem_used (GB)",
                r"mem_used (%)", r"swap_used (GB)", r"gpu_utilization (%)"
            ],
                                           dtype="float16")
            for device in range(torch.cuda.device_count()):
                self.sysmetrics["mem_allocd (GB) by torch on {:10s}".format(
                    torch.cuda.get_device_name(device))] = np.nan
                self.sysmetrics["mem_cached (GB) by torch on {:10s}".format(
                    torch.cuda.get_device_name(device))] = np.nan
            self.sysmetrics_start(sysmetrics_interval)

    def __getattr__(self, attr):
        """delegate all undefined method requests to objects of
        this class in order pylogger, tboard (first find first serve).
        E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...)
        """
        for obj in [self.pylogger, self.tboard]:
            if attr in dir(obj):
                return getattr(obj, attr)
        raise AttributeError("CombinedLogger has no attribute {}".format(attr))

    def time(self, name, toggle=None):
        """record time-spans as with a stopwatch.
        :param name:
        :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status.
        :return: either start-time or last recorded interval
        """
        if toggle is None:
            if name in self.times.keys():
                toggle = not self.times[name]["toggle"]
            else:
                toggle = True

        if toggle:
            if not name in self.times.keys():
                self.times[name] = {"total": 0, "last": 0}
            elif self.times[name]["toggle"] == toggle:
                print("restarting running stopwatch")
            self.times[name]["last"] = time.time()
            self.times[name]["toggle"] = toggle
            return time.time()
        else:
            if toggle == self.times[name]["toggle"]:
                self.info(
                    "WARNING: tried to stop an already stopped stopwatch: {}.".format(
                        name))
            self.times[name]["last"] = time.time() - self.times[name]["last"]
            self.times[name]["total"] += self.times[name]["last"]
            self.times[name]["toggle"] = toggle
            return self.times[name]["last"]

    def get_time(self, name=None, kind="total", format=None, reset=False):
        """
        :param name:
        :param kind: 'total' or 'last'
        :param format: None for float, "hms"/"ms" for (hours), mins, secs as string
        :param reset: reset time after retrieving
        :return:
        """
        if name is None:
            times = self.times
            if reset:
                self.reset_time()
            return times

        else:
            time = self.times[name][kind]
            if format == "hms":
                m, s = divmod(time, 60)
                h, m = divmod(m, 60)
                time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s))
            elif format == "ms":
                m, s = divmod(time, 60)
                time = "{:02d}m:{:02d}s".format(int(m), int(s))
            if reset:
                self.reset_time(name)
            return time

    def reset_time(self, name=None):
        if name is None:
            self.times = {}
        else:
            del self.times[name]

    def sysmetrics_update(self, global_step=None):
        if global_step is None:
            global_step = time.strftime("%x_%X")
        mem = psutil.virtual_memory()
        mem_used = (mem.total - mem.available)
        gpu_vals = self.gpu_logger.get_vals()
        rel_time = time.time() - self.sysmetrics_start_time
        self.sysmetrics.loc[len(self.sysmetrics)] = [
            global_step, rel_time,
            psutil.cpu_percent(), mem_used / 1024**3,
            mem_used / mem.total * 100,
            psutil.swap_memory().used / 1024**3,
            int(gpu_vals['gpu_graphics_util']), *[
                torch.cuda.memory_allocated(d) / 1024**3
                for d in range(torch.cuda.device_count())
            ], *[
                torch.cuda.memory_cached(d) / 1024**3
                for d in range(torch.cuda.device_count())
            ]
        ]
        return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict()

    def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None):
        tag = "per_time"
        if metrics is None:
            metrics = self.sysmetrics_update(global_step=global_step)
            tag = "per_epoch"

        if suptitle is not None:
            suptitle = str(suptitle)
        elif self.fold != "":
            suptitle = "Fold_" + str(self.fold)
        if suptitle is not None:
            self.tboard.add_scalars(
                suptitle + "/System_Metrics/" + tag, {
                    k: v
                    for (k, v) in metrics.items()
                    if (k != "global_step" and k != "rel_time")
                }, global_step)

    def sysmetrics_loop(self):
        try:
            os.nice(-19)
        except OSError:
            print("System-metrics logging has no superior process priority.")
        while True:
            metrics = self.sysmetrics_update()
            self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"])
            #print("thread alive", self.thread.is_alive())
            time.sleep(self.sysmetrics_interval)

    def sysmetrics_start(self, interval):
        if interval is not None:
            self.sysmetrics_interval = interval
            self.gpu_logger = Nvidia_GPU_Logger()
            self.sysmetrics_start_time = time.time()
            self.thread = threading.Thread(target=self.sysmetrics_loop)
            self.thread.daemon = True
            self.thread.start()

    def sysmetrics_save(self, out_file):

        self.sysmetrics.to_pickle(out_file)

    def metrics2tboard(self, metrics, global_step=None, suptitle=None):
        """
        :param metrics: {'train': dataframe, 'val':df}, df as produced in
            evaluator.py.evaluate_predictions
        """
        #print("metrics", metrics)
        if global_step is None:
            global_step = len(metrics['train'][list(
                metrics['train'].keys())[0]]) - 1
        if suptitle is not None:
            suptitle = str(suptitle)
        else:
            suptitle = "Fold_" + str(self.fold)

        for key in ['train', 'val']:
            #series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k}
            loss_series = {}
            unc_series = {}
            bin_stat_series = {}
            mon_met_series = {}
            for tag, val in metrics[key].items():
                val = val[-1]  # maybe remove list wrapping, recording in evaluator?
                if 'bin_stats' in tag.lower() and not np.isnan(val):
                    bin_stat_series["{}".format(tag.split("/")[-1])] = val
                elif 'uncertainty' in tag.lower() and not np.isnan(val):
                    unc_series["{}".format(tag)] = val
                elif 'loss' in tag.lower() and not np.isnan(val):
                    loss_series["{}".format(tag)] = val
                elif not np.isnan(val):
                    mon_met_series["{}".format(tag)] = val

            self.tboard.add_scalars(
                suptitle + "/Binary_Statistics/{}".format(key),
                bin_stat_series, global_step)
            self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key),
                                    unc_series, global_step)
            self.tboard.add_scalars(suptitle + "/Losses/{}".format(key),
                                    loss_series, global_step)
            self.tboard.add_scalars(
                suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series,
                global_step)
        self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"],
                                global_step)
        return

    def batchImgs2tboard(self,
                         batch,
                         results_dict,
                         cmap,
                         boxtype2color,
                         img_bg=False,
                         global_step=None):
        raise NotImplementedError(
            "not up-to-date, problem with importing plotting-file, torchvision dependency."
        )
        if len(batch["seg"].shape) == 5:  #3D imgs
            slice_ix = np.random.randint(batch["seg"].shape[-1])
            seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap)
            seg_pred = plg.to_rgb(
                results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap)

            mod_img = plg.mod_to_rgb(
                batch["data"][:, 0, :, :, slice_ix]) if img_bg else None

        elif len(batch["seg"].shape) == 4:
            seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap)
            seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap)
            mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None
        else:
            raise Exception("batch content has wrong format: {}".format(
                batch["seg"].shape))

        #from here on only works in 2D
        seg_gt = np.transpose(seg_gt,
                              axes=(0, 3, 1, 2))  #previous shp: b,x,y,c
        seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2))

        seg = np.concatenate((seg_gt, seg_pred), axis=0)
        # todo replace torchvision (tv) dependency
        seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2)
        self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.",
                              seg,
                              global_step=global_step)

        if img_bg:
            bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2))
        else:
            bg_img = seg_gt
        box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"],
                                             boxtype2color)
        box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4)
        self.tboard.add_image("Batch bboxes",
                              box_imgs,
                              global_step=global_step)

        return

    def __del__(
        self
    ):  # otherwise might produce multiple prints e.g. in ipython console
        for hdlr in self.pylogger.handlers:
            hdlr.close()
        self.tboard.close()
        self.pylogger.handlers = []
        del self.pylogger
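
CombinedLogger forwards unknown attribute lookups to its wrapped loggers via __getattr__, so a call like combinedlogger.add_scalars(...) transparently reaches the tensorboard writer. The delegation pattern in isolation (names here are illustrative):

class Delegator:
    def __init__(self, *targets):
        self._targets = targets

    def __getattr__(self, attr):
        # called only when normal attribute lookup fails; forward the
        # request to the first target object that provides the attribute
        for obj in self._targets:
            if hasattr(obj, attr):
                return getattr(obj, attr)
        raise AttributeError(attr)

# e.g. Delegator(pylogger, tboard).add_scalar(...) resolves on tboard
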
Example No. 27
0
class Trainer(object):
    def __init__(self, config, train_data_loader, test_data_loader):
        self.config = config
        self.train_data_loader = train_data_loader
        self.test_data_loader = test_data_loader
        self.start_step = 0
        self.tensorboard = None
        self._build_model()

        if config.num_gpu > 0:
            self.NoiseGenerator = DataParallelWithCallback(
                self.NoiseGenerator.cuda(), device_ids=range(config.num_gpu))
            self.Classifier = DataParallelWithCallback(self.Classifier.cuda(),
                                                       device_ids=range(
                                                           config.num_gpu))

        # # Note: check whether :0 is necessary or not.
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if config.load_path:
            self._load_model()

        # create the attacker modules
        self.FGSM = get_fgsm(self.config.dataset)
        self.PGD = get_pgd(self.config.dataset)
        self.CW = get_cw(self.config.dataset)

    def _build_model(self):
        noise_channel_size = (3 if self.config.is_rgb else
                              1) * (1 +
                                    (1 if self.config.g_method == 3 else 0) +
                                    (1 if self.config.g_use_grad else 0))

        self.NoiseGenerator = NoiseGenerator(self.config.g_base_channel_dim,
                                             noise_channel_size,
                                             self.config.g_z_dim,
                                             self.config.g_deeper_layer,
                                             self.config.num_classes,
                                             3 if self.config.is_rgb else 1)
        self.Classifier = Classifier(
            num_classes=self.config.num_classes,
            classifier_name=self.config.f_classifier_name,
            dataset=self.config.dataset,
            pretrained=self.config.f_pretrain,
            pretrained_dir=self.config.pretrained_dir)
        self.NoiseGenerator.apply(weights_init_normal)
        if not self.config.f_pretrain:
            self.Classifier.apply(weights_init_normal)

    def _load_model(self):
        print("[*] Load models from {}...".format(self.config.load_path))
        paths = glob(os.path.join(self.config.load_path, 'Classifier_*.pth'))
        paths.sort()

        if len(paths) == 0:
            path = os.path.join(self.config.load_path, 'Classifier.pth')
            if not os.path.exists(path):
                print("[!] No checkpoint found in {}...".format(
                    self.config.load_path))
                return
            self.start_step = 0
        else:
            idxes = [
                int(os.path.basename(path.split('.')[-2].split('_')[-1]))
                for path in paths
            ]
            self.start_step = max(idxes)

        if self.config.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        if self.config.f_update_style != -1:
            bad_classifier_state = torch.load('{}/Classifier_{}.pth'.format(
                self.config.load_path, self.start_step),
                                              map_location=map_location)
            starts_with_module = False
            for key in bad_classifier_state.keys():
                if key.startswith('module.'):
                    starts_with_module = True
                    break
            if starts_with_module and (self.config.num_gpu < 1):
                correct_classifier_state = {
                    k[7:]: v
                    for k, v in bad_classifier_state.items()
                }
            else:
                correct_classifier_state = bad_classifier_state
            self.Classifier.load_state_dict(correct_classifier_state)

        if self.config.f_update_style != -1:
            bad_generator_state = torch.load('{}/NoiseGen_{}.pth'.format(
                self.config.load_path, self.start_step),
                                             map_location=map_location)
        else:
            bad_generator_state = torch.load('{}/Generator.pth'.format(
                self.config.load_path),
                                             map_location=map_location)

        starts_with_module = False
        for key in bad_generator_state.keys():
            if key.startswith('module.'):
                starts_with_module = True
                break
        if starts_with_module and (self.config.num_gpu < 1):
            correct_generator_state = {
                k[7:]: v
                for k, v in bad_generator_state.items()
            }
        else:
            correct_generator_state = bad_generator_state
        self.NoiseGenerator.load_state_dict(correct_generator_state)

    def _save_model(self, step):
        print("[*] Save models to {}...".format(self.config.model_dir))
        torch.save(self.Classifier.state_dict(),
                   '{}/Classifier_{}.pth'.format(self.config.model_dir, step))
        torch.save(self.NoiseGenerator.state_dict(),
                   '{}/NoiseGen_{}.pth'.format(self.config.model_dir, step))

    def _merge_noise(self, sum_noise, cur_noise, eps_step, eps_all):
        # 0. normalize noise output first: Don't need to, since we always take the tanh output

        # 1. multiply epsilon (with randomness for the training)
        # result: noise is in -eps_step < noise < eps_step
        cur_noise = cur_noise * eps_step

        # 2. return mixed output
        return torch.clamp(sum_noise + cur_noise, -1.0 * eps_all,
                           1.0 * eps_all)

    def _cross_entropy_loss(self,
                            noise_class_output,
                            label,
                            pure_batch,
                            adv_mult=1.0):
        log_prob = F.log_softmax(noise_class_output, dim=1)
        weight = torch.ones_like(label).float()
        weight[pure_batch:] *= adv_mult
        output = F.nll_loss(log_prob, label, reduction='none')
        return torch.mean(weight * output)

    def _compute_acc(self, logits, labels):
        #logits = logits / torch.norm(logits)
        _max_val, max_idx = torch.max(logits, 1)
        return torch.mean(torch.eq(max_idx, labels).double())

    # compute our loss from the output (batch major!)
    def _dsgan_loss(self, noise, output, single_batch, stability=1e-8):
        if noise is None:
            return None

        numerator = torch.mean(torch.abs(output[:single_batch] -
                                         output[single_batch:]),
                               dim=[_ for _ in range(1, len(output.shape))])
        denominator = torch.mean(torch.abs(noise[:single_batch] -
                                           noise[single_batch:]),
                                 dim=[_ for _ in range(1, len(noise.shape))])
        our_term = torch.mean(numerator / (denominator + stability))
        return our_term
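    # _dsgan_loss above computes a diversity-sensitive objective: the mean over
    # paired samples of |output_i - output_j| / (|noise_i - noise_j| + stability),
    # encouraging the generator's outputs to vary with its input noise.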

    def train(self):
        # Optimizer for G
        if self.config.g_optimizer == 'adam':
            g_optimizer = torch.optim.Adam(
                self.NoiseGenerator.parameters(),
                lr=self.config.g_lr,
                betas=(self.config.g_beta1, self.config.g_beta2),
                weight_decay=self.config.weight_decay)
        elif self.config.g_optimizer == 'sgd':
            g_optimizer = torch.optim.SGD(
                self.NoiseGenerator.parameters(),
                lr=self.config.g_lr,
                momentum=self.config.g_momentum,
                weight_decay=self.config.weight_decay)
        else:
            raise Exception(
                "[!] Optimizer for the generator should be ['adam', 'sgd']")

        # set initial learning rate for the case that it starts training from the middle
        if self.config.f_update_style == 2:
            if self.start_step != 0:
                for group in g_optimizer.param_groups:
                    group.setdefault('initial_lr', self.config.g_lr)
            g_scheduler = torch.optim.lr_scheduler.StepLR(
                g_optimizer,
                step_size=self.config.max_step // 2,
                gamma=self.config.lr_gamma,
                last_epoch=(-1 if self.start_step == 0 else self.start_step))
        else:
            g_scheduler = None

        # Optimizer for F
        if self.config.f_optimizer == 'adam':
            f_optimizer = torch.optim.Adam(
                self.Classifier.parameters(),
                lr=self.config.f_lr,
                betas=(self.config.f_beta1, self.config.f_beta2),
                weight_decay=self.config.weight_decay)
        elif self.config.f_optimizer == 'sgd':
            f_optimizer = torch.optim.SGD(
                self.Classifier.parameters(),
                lr=self.config.f_lr,
                momentum=self.config.f_momentum,
                weight_decay=self.config.weight_decay)
        else:
            raise Exception(
                "[!] Optimizer for the classifier should be ['adam', 'sgd']")

        # set initial_lr before constructing the scheduler; StepLR with
        # last_epoch != -1 requires it when resuming
        if self.start_step != 0:
            for group in f_optimizer.param_groups:
                group.setdefault('initial_lr', self.config.f_lr)
        f_scheduler = torch.optim.lr_scheduler.StepLR(
            f_optimizer,
            step_size=self.config.max_step // 2,
            gamma=self.config.lr_gamma,
            last_epoch=(-1 if self.start_step == 0 else self.start_step))

        # now load the train data
        loader = iter(self.train_data_loader)

        # train mode
        self.tensorboard = SummaryWriter(self.config.model_dir)
        self.tensorboard.add_text(tag='argument',
                                  text_string=str(self.config.__dict__))
        for step in trange(self.start_step, self.config.max_step, ncols=80):
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(self.train_data_loader)
                data = next(loader)

            # convert the input tensor to float
            real_img = self._get_variable(data[0].type(torch.FloatTensor))
            if (not self.config.is_rgb) and (len(real_img.shape) == 3):
                real_img = torch.unsqueeze(real_img, 1)  # N W H -> N C W H
            label = self._get_variable(data[1].type(torch.LongTensor))
            single_batch_size = label.size(0)

            # step the learning-rate schedulers (for both f and g)
            if f_scheduler is not None:
                f_scheduler.step()
            if g_scheduler is not None:
                g_scheduler.step()

            # MNIST w/ LeNet case only:
            # (pretrain for 1K steps so the classifier is reasonably trained)
            # For all other cases, a pretrained classifier is loaded.
            # If you have pretrained weights, you can start from there.
            if (step < 1000) and (self.config.f_classifier_name == 'lenet'):
                self.Classifier.train()
                self.Classifier.zero_grad()
                class_output = self.Classifier(real_img)
                cls_loss = self._cross_entropy_loss(class_output, label,
                                                    single_batch_size)
                cls_loss.backward()
                f_optimizer.step()
                continue

            ######## Phase 1 #######
            # obtain the input gradient from the classifier before training G
            self.Classifier.eval()
            self.Classifier.zero_grad()
            grad_input = real_img.detach()
            grad_input.requires_grad = True
            class_output = self.Classifier.forward(grad_input)

            # compute loss
            # cls_loss is already averaged when computed, so it needs no re-scaling.
            cls_loss = self._cross_entropy_loss(class_output, label,
                                                single_batch_size)
            # Add other losses here if you want.
            grad_loss = cls_loss

            if self.config.g_use_grad:
                # obtain gradients and disable the gradient for the input
                grad_loss.backward()
                f_grad = grad_input.grad
                # normalize the gradient input
                # (swap in another normalization here if needed)
                if self.config.g_normalize_grad:
                    f_grad_norm = f_grad + 1e-15  # DO NOT EDIT! Need a stabilizer in here!!!
                    f_grad = f_grad / f_grad_norm.norm(dim=(2, 3),
                                                       keepdim=True)

                f_grad = f_grad.detach()  # for a sanity check purpose

            ######## Phase 2 ###########
            # Train the generator (not the classifier); the classifier is
            # still required to provide the gradient signal.
            double_real_img = torch.cat((real_img, real_img), 0).detach()
            double_label = torch.cat((label, label), 0).detach()

            if self.config.g_method % 2 == 1:
                double_adv_sum = torch.zeros_like(double_real_img)
            else:
                double_adv_sum = None

            if self.config.g_use_grad:
                double_adv_grad = torch.cat((f_grad, f_grad), 0)
            else:
                double_adv_grad = None

            if self.config.g_z_dim > 0:
                if self.config.num_gpu > 0:
                    g_z = torch.cuda.FloatTensor(
                        single_batch_size * 2, self.config.g_z_dim).normal_()
                else:
                    g_z = torch.FloatTensor(single_batch_size * 2,
                                            self.config.g_z_dim).normal_()
            else:
                g_z = None

            self.NoiseGenerator.train()
            self.Classifier.eval()
            self.NoiseGenerator.zero_grad()
            proxy_loss_sum = 0.0
            dsgan_loss_sum = 0.0
            update_list = []

            if self.config.g_mini_update_style not in [0, 1, 2, 3]:
                raise Exception(
                    "[!] g_mini_update_style should be in [0,1,2,3]")

            for g_iter_step_no in range(self.config.train_g_iter):
                if not self.config.use_cross_entropy_for_g:
                    print(
                        "[!] We cannot train our generator without cross_entropy"
                    )
                    break

                # generate the current pair
                img_grad_noise = double_real_img

                if self.config.g_use_grad:
                    img_grad_noise = torch.cat(
                        (double_real_img, double_adv_grad), 1)

                # in case of recursive gen
                if self.config.g_method == 3:
                    img_grad_noise = torch.cat(
                        (img_grad_noise, double_adv_sum), 1)

                # feed it to the generator
                noise_output_for_g = self.NoiseGenerator(
                    img_grad_noise, double_label, g_z)

                # clamping learned noise in epsilon boundary
                if self.config.g_method % 2 == 1:
                    clamp_noise = self._merge_noise(
                        double_adv_sum, noise_output_for_g,
                        self.config.epsilon * self.config.g_ministep_size,
                        self.config.epsilon)
                else:
                    clamp_noise = self.config.epsilon * noise_output_for_g

                # clamping it once again to image boundary
                adv_img_for_g = torch.clamp(
                    double_real_img.detach() + clamp_noise, 0.0, 1.0)

                # first obtain the gradient information for the current result
                copy_for_grad = adv_img_for_g.detach()
                copy_for_grad.requires_grad = True

                # compute & accumulate classification gradients
                if (self.config.g_mini_update_style % 2
                        == 0) or (g_iter_step_no + 1
                                  == self.config.train_g_iter):
                    self.Classifier.zero_grad()
                    noise_class_output_for_g = self.Classifier.forward(
                        adv_img_for_g)
                    # compute gradient for the generator
                    # - cross entropy will increase the confidence quickly
                    # - cw loss can find the attack region effectively (fast).
                    proxy_loss = 0.0
                    if self.config.use_cross_entropy_for_g:
                        ce_loss = self._cross_entropy_loss(
                            noise_class_output_for_g, double_label,
                            single_batch_size)
                        proxy_loss -= ce_loss

                    # add other losses if you want.
                    proxy_loss_sum += proxy_loss

                # compute & accumulate DSGAN gradient
                if (self.config.g_z_dim > 0) and \
                        ((self.config.g_mini_update_style >= 2) or (g_iter_step_no + 1 == self.config.train_g_iter)):

                    # compute our loss and add it to g_loss
                    dsgan_magnitude = self._dsgan_loss(g_z, adv_img_for_g,
                                                       single_batch_size)
                    if self.config.dsgan_lambda > 0.0:
                        dsgan_loss = -1.0 * self.config.dsgan_lambda * dsgan_magnitude
                    else:
                        dsgan_loss = 0.0
                    dsgan_loss_sum += dsgan_loss

                # preparing for the next mini-step
                if g_iter_step_no + 1 != self.config.train_g_iter:
                    if self.config.g_use_grad:
                        # compute gradient information for the next time step
                        self.Classifier.zero_grad()
                        grad_output_for_g = self.Classifier.forward(
                            copy_for_grad)
                        grad_ce_loss = self._cross_entropy_loss(
                            grad_output_for_g, double_label, single_batch_size)
                        grad_loss = grad_ce_loss
                        grad_loss.backward()

                        # obtain gradients and disable the gradient for the input
                        f_grad = copy_for_grad.grad
                        # normalize the gradient input
                        # (swap in another normalization here if needed)
                        if self.config.g_normalize_grad:
                            f_grad_norm = f_grad + 1e-15  # DO NOT EDIT! Need a stabilizer in here!!!
                            f_grad = f_grad / f_grad_norm.norm(dim=(2, 3),
                                                               keepdim=True)
                        double_adv_grad = f_grad.detach()

                    if double_adv_sum is not None:
                        double_adv_sum = clamp_noise

                update_list.append(adv_img_for_g.detach())

            # with train_g_iter == 0 the loss sums are plain floats and have
            # no .backward(), so guard the generator update
            if self.config.train_g_iter > 0:
                g_loss_sum = proxy_loss_sum + dsgan_loss_sum
                g_loss_sum.backward()
                # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
                nn.utils.clip_grad_norm_(self.NoiseGenerator.parameters(), 1.0)
                g_optimizer.step()

            ######## Phase 3 #########
            # train the classifier (the "discriminator" of this setup)
            if self.config.f_update_style == 1:
                # merge update
                f_label_list = [torch.cat((label, label, label), 0)]
                f_update_list = [torch.cat((real_img, update_list[-1]), 0)]

            elif self.config.f_update_style == 2:
                # update false labels first
                # then update the true label
                f_label_list = [double_label, label]
                f_update_list = [update_list[-1], real_img]
            elif self.config.f_update_style == -1:
                # finetune our generator only
                if (step % self.config.save_step) == (self.config.save_step -
                                                      1):
                    self._save_model(step)
                    self.defence_regular_eval(iter_step=step)
                continue
            else:
                raise Exception(
                    "[!] f_update_style should be [1: single, 2: twice, -1: finetune G only]")

            self.Classifier.train()
            noise_class_output_for_debugging = None
            noise_class_loss_for_debugging = None
            real_pred_sum = 0.0
            fake_pred_sum = 0.0
            for image_for_f, label_for_f in zip(f_update_list, f_label_list):
                self.Classifier.zero_grad()

                noise_class_output = self.Classifier(image_for_f)
                if noise_class_output_for_debugging is None:
                    noise_class_output_for_debugging = noise_class_output

                cls_loss = self._cross_entropy_loss(noise_class_output,
                                                    label_for_f,
                                                    single_batch_size)
                if noise_class_loss_for_debugging is None:
                    noise_class_loss_for_debugging = cls_loss

                f_loss = cls_loss
                # update the classifier
                f_loss.backward()
                # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
                nn.utils.clip_grad_norm_(self.Classifier.parameters(), 1.0)
                f_optimizer.step()

            ######## Logging ##########
            f_acc = self._compute_acc(class_output[-single_batch_size:],
                                      label).data
            self.tensorboard.add_scalar('train/f_loss', f_loss.data, step)
            self.tensorboard.add_scalar('train/f_acc', f_acc, step)
            if self.config.dsgan_lambda > 0.0:
                self.tensorboard.add_scalar('train/lambda',
                                            self.config.dsgan_lambda, step)

            fg_acc = self._compute_acc(
                noise_class_output_for_debugging[-single_batch_size * 2:],
                double_label).data
            fg_loss = self._cross_entropy_loss(
                noise_class_output_for_debugging[-single_batch_size * 2:],
                double_label, single_batch_size)
            self.tensorboard.add_scalar('train/fg_cls_loss', fg_loss.data,
                                        step)
            self.tensorboard.add_scalar('train/fg_acc', fg_acc, step)
            if step % self.config.log_step == 0:
                print("")
                print("[{}/{}] Acc_F: {:.4f} F_loss: {:.4f} Acc_FG: {:.4f} cls_loss: {:.4f}".\
                       format(step, self.config.max_step, f_acc, f_loss.data, fg_acc, fg_loss.data))

            if self.config.train_g_iter > 0:
                self.tensorboard.add_scalar('train/proxy_loss_sum',
                                            proxy_loss_sum.data, step)
                self.tensorboard.add_scalar('train/proxy_loss_last',
                                            proxy_loss.data, step)
                if (self.config.g_z_dim > 0) and (dsgan_magnitude is not None):
                    if self.config.dsgan_lambda > 0.0:
                        self.tensorboard.add_scalar('train/dsgan_loss_sum',
                                                    dsgan_loss_sum.data, step)
                    self.tensorboard.add_scalar('train/dsgan_loss_last',
                                                dsgan_magnitude.data, step)
                    if step % self.config.log_step == 0:
                        print("[{}/{}] our_loss: {:.4f} proxy_loss: {:.4f}". \
                              format(step, self.config.max_step, dsgan_magnitude.data, proxy_loss.data))
                self.tensorboard.add_scalar('train/g_loss_sum',
                                            g_loss_sum.data, step)

            # save checkpoints and noise image
            if step % self.config.save_step == self.config.save_step - 1:

                if self.config.g_use_grad:
                    slice1 = update_list[-1][:single_batch_size]
                    slice2 = update_list[-1][single_batch_size:]

                    grad_abs = torch.abs(double_adv_grad)
                    grad_min = torch.min(grad_abs)
                    grad_rescale = grad_abs - grad_min
                    grad_max = torch.max(grad_rescale)
                    grad_rescale /= grad_max
                    grad_slice1 = grad_rescale[:single_batch_size]
                    grad_slice2 = grad_rescale[single_batch_size:]
                    self.tensorboard.add_image(
                        'train/pair1',
                        tvutils.make_grid(torch.cat(
                            (real_img[:15], slice1[:15], slice2[:15],
                             grad_slice1[:15], grad_slice2[:15]), 0),
                                          nrow=15), step)
                    self.tensorboard.add_image(
                        'train/pair2',
                        tvutils.make_grid(torch.cat(
                            (real_img[15:30], slice1[15:30], slice2[15:30],
                             grad_slice1[15:30], grad_slice2[15:30]), 0),
                                          nrow=15), step)

                self._save_model(step)
                self.defence_regular_eval(iter_step=step)

    def _test_classifier(self,
                         image_tensor,
                         label_tensor,
                         iter_step=0,
                         method_name='PGD'):
        total_acc_f = []
        num_items = len(label_tensor)
        self.Classifier.eval()
        for index in range(0, num_items, self.config.single_batch_size):
            # first slice into batch
            adv_img = image_tensor[
                index:min(index + self.config.single_batch_size, num_items)]
            label = label_tensor[
                index:min(index + self.config.single_batch_size, num_items)]

            # run classifier
            logits = self.Classifier.forward(adv_img)

            # get accuracy
            acc_f = self._compute_acc(logits, label)
            total_acc_f.append(acc_f.data)

        # aggregate the performance
        performance = sum(total_acc_f) / len(total_acc_f)
        print("[{} / {}] Acc: {:.4f}".format(method_name, iter_step,
                                             performance))

        if self.tensorboard is not None:
            self.tensorboard.add_scalar('test/{}_acc'.format(method_name),
                                        performance, iter_step)

    def get_sample_pdf_of_checkpoint(self, default_z_iter=10):
        loader = iter(self.test_data_loader)

        test_dir = os.path.join(self.config.model_dir, 'test')
        if not os.path.exists(test_dir):
            os.makedirs(test_dir)

        self.Classifier.eval()
        self.NoiseGenerator.eval()
        total_acc_f = []
        total_acc_g = []
        real_img_arr = []
        real_label_arr = []
        adv_img_arr = []
        adv_att_arr = []

        for step in trange(len(self.test_data_loader), ncols=80):
            try:
                data = next(loader)
            except StopIteration:
                print("[!] Test sample generation finished. Samples are in {}".
                      format(test_dir))
                break

            real_img = self._get_variable(data[0].type(torch.FloatTensor))
            if (not self.config.is_rgb) and (len(real_img.shape) == 3):
                real_img = torch.unsqueeze(real_img, 1)

            label = self._get_variable(data[1].type(torch.LongTensor))
            single_batch_size = label.size(0)

            ######## Phase 1 #######
            # obtain the input gradient from the classifier before running G
            self.Classifier.zero_grad()
            grad_input = real_img.detach()
            grad_input.requires_grad = True
            class_output = self.Classifier.forward(grad_input)

            # compute loss
            f_loss = self._cross_entropy_loss(class_output, label,
                                              single_batch_size)
            f_loss.backward()

            if self.config.g_use_grad:
                f_grad = grad_input.grad
                if self.config.g_normalize_grad:
                    f_grad_norm = f_grad + 1e-15  # DO NOT EDIT! Need a stabilizer in here!!!
                    f_grad = f_grad / f_grad_norm.norm(dim=(2, 3),
                                                       keepdim=True)

            ######## Phase 2 #######
            num_iter_z = default_z_iter if self.config.g_z_dim > 0 else 1
            adv_img_inner_arr = []
            adv_att_inner_arr = []
            for _ in range(num_iter_z):
                # f_grad only exists when g_use_grad is set
                adv_grad = f_grad.detach() if self.config.g_use_grad else None

                if self.config.g_method % 2 == 1:
                    adv_sum = torch.zeros_like(real_img)
                else:
                    adv_sum = None

                if self.config.g_z_dim > 0:
                    if self.config.num_gpu > 0:
                        g_z = torch.cuda.FloatTensor(
                            single_batch_size, self.config.g_z_dim).normal_()
                    else:
                        g_z = torch.FloatTensor(single_batch_size,
                                                self.config.g_z_dim).normal_()
                else:
                    g_z = None

                self.NoiseGenerator.zero_grad()
                for g_iter_step_no in range(self.config.train_g_iter):
                    img_grad_noise = real_img
                    if self.config.g_use_grad:
                        img_grad_noise = torch.cat((img_grad_noise, adv_grad),
                                                   1)

                    if self.config.g_method == 3:
                        img_grad_noise = torch.cat((img_grad_noise, adv_sum),
                                                   1)

                    # feed it to the generator
                    noise_output = self.NoiseGenerator.forward(
                        img_grad_noise, label, g_z)

                    # generate learned noise
                    if self.config.g_method % 2 == 1:
                        clamp_noise = self._merge_noise(
                            adv_sum, noise_output,
                            self.config.epsilon * self.config.g_ministep_size,
                            self.config.epsilon)
                    else:
                        clamp_noise = self.config.epsilon * noise_output
                    adv_img_for_g = torch.clamp(
                        real_img.detach() + clamp_noise, 0.0, 1.0)
                    copy_for_grad = adv_img_for_g.detach()
                    copy_for_grad.requires_grad = True

                    # preparing for the next mini-step
                    # Note: we are not updating the Generator.
                    if g_iter_step_no + 1 != self.config.train_g_iter:
                        if self.config.g_use_grad:
                            self.Classifier.zero_grad()
                            grad_output_for_g = self.Classifier.forward(
                                copy_for_grad)
                            grad_ce_loss = self._cross_entropy_loss(
                                grad_output_for_g, label, single_batch_size)
                            grad_loss = grad_ce_loss
                            grad_loss.backward()

                            # obtain gradients and disable the gradient for the input
                            f_inner_grad = copy_for_grad.grad
                            if self.config.g_normalize_grad:
                                f_inner_grad_norm = f_inner_grad + 1e-15  # DO NOT EDIT! Need a stabilizer in here!!!
                                f_inner_grad = f_inner_grad / f_inner_grad_norm.norm(
                                    dim=(2, 3), keepdim=True)
                            adv_grad = f_inner_grad.detach()

                        if adv_sum is not None:
                            adv_sum = clamp_noise

                    # keep the latest adversarial image and attack noise
                    target_image = adv_img_for_g.detach()
                    target_attack = clamp_noise.detach()

                adv_img_inner_arr.append(target_image.detach().data)
                adv_att_inner_arr.append(target_attack.detach().data)

            self.Classifier.zero_grad()
            class_output = self.Classifier.forward(real_img)
            noise_class_output = self.Classifier.forward(target_image)
            acc_f = self._compute_acc(class_output, label)
            acc_g = self._compute_acc(noise_class_output, label)
            total_acc_f.append(acc_f.data)
            total_acc_g.append(acc_g.data)

            real_img_arr.append(real_img.unsqueeze(1).detach().data)
            real_label_arr.append(label.data)
            adv_img_arr.append(
                torch.transpose(torch.stack(adv_img_inner_arr), 0, 1))
            adv_att_arr.append(
                torch.transpose(torch.stack(adv_att_inner_arr), 0, 1))

        print("[{}] Acc_F: {:.4f}, Acc_FG: {}".format(
            test_dir,
            sum(total_acc_f) / len(total_acc_f),
            sum(total_acc_g) / len(total_acc_g)))

        print("Converting the results into numpy format.")
        real_img_arr = torch.cat(real_img_arr, 0)
        orig_data_cpu = real_img_arr.mul(255).clamp(0, 255).byte().permute(
            0, 1, 3, 4, 2).cpu().numpy()

        real_label_arr = torch.cat(real_label_arr, 0)
        orig_label_cpu = real_label_arr.to(dtype=torch.int16).cpu().numpy()

        adv_img_arr = torch.cat(adv_img_arr, 0)
        adv_img_cpu = adv_img_arr.mul(255).clamp(0, 255).byte().permute(
            0, 1, 3, 4, 2).cpu().numpy()

        adv_att_arr = torch.clamp(
            (1.0 + torch.cat(adv_att_arr, 0) / self.config.epsilon) / 2.0, 0.0,
            1.0)
        adv_att_cpu = adv_att_arr.mul(255).clamp(0, 255).byte().permute(
            0, 1, 3, 4, 2).cpu().numpy()

        print("start generating a pdf file")
        item_dict_for_pdf = {}
        for real, label, img, att in zip(orig_data_cpu, orig_label_cpu,
                                         adv_img_cpu, adv_att_cpu):
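            # Note (added): att[4:] keeps the last 6 z-samples, which assumes
            # the default value default_z_iter == 10; the pairwise spread
            # below scores how diverse those attacks are.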
            current_std = np.reshape(att[4:], [6, -1])
            current_std = np.expand_dims(current_std, 1) - np.expand_dims(
                current_std, 0)
            current_std = np.mean(
                np.sum(current_std * current_std, axis=-1) * (1 - np.eye(6)))

            temp_arr = np.concatenate([img[4:], att[4:]], axis=0)
            temp_arr = np.transpose(temp_arr, (1, 0, 2, 3))
            shape = temp_arr.shape

            if (label not in item_dict_for_pdf) or (item_dict_for_pdf[label][0]
                                                    < current_std):
                if shape[3] == 1:
                    item_dict_for_pdf[label] = [
                        current_std,
                        np.reshape(temp_arr, (shape[0], shape[1] * shape[2]))
                    ]
                else:
                    item_dict_for_pdf[label] = [
                        current_std,
                        np.reshape(temp_arr,
                                   (shape[0], shape[1] * shape[2], shape[3]))
                    ]

        sorted_list = [
            item_dict_for_pdf[_][1] for _ in range(self.config.num_classes)
        ]
        output = np.concatenate(sorted_list, axis=0)

        print("start saving it in {} as vis_{}.pdf".format(
            self.config.log_dir, self.config.model_name))
        # scipy.misc.imsave was removed in SciPy 1.2; save through PIL instead
        from PIL import Image
        Image.fromarray(output).save(
            os.path.join(self.config.log_dir,
                         'vis_{}.pdf'.format(self.config.model_name)))

    def _run_single_attack(self, iter_step=0, method_name='PGD'):
        # set test dir to save
        test_dir = os.path.join(self.config.model_dir, 'test')
        if not os.path.exists(test_dir):
            os.makedirs(test_dir)

        # set a new data_loader
        loader = iter(self.test_data_loader)
        steps_required_per_epoch = len(loader)
        if method_name.endswith('_slow'):
            steps_required_per_epoch = 5

        print('[Info] steps per evaluation pass: {}'.format(
            steps_required_per_epoch))
        print('[Info] Start running {} for step {}'.format(
            method_name, iter_step))

        # run attack mechanism
        output_list = []
        target_list = []
        self.Classifier.eval()
        for step in range(steps_required_per_epoch):
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(self.test_data_loader)
                data = next(loader)

            # convert the input tensor to float
            input_img = self._get_variable(data[0].type(torch.FloatTensor))
            if (not self.config.is_rgb) and (len(input_img.shape) == 3):
                input_img = torch.unsqueeze(input_img, 1)
            target_label = self._get_variable(data[1].type(torch.LongTensor))
            single_batch_size = target_label.size(0)

            if method_name == 'FGSM':
                adv_result = run_fgsm(self.FGSM, self.Classifier, input_img,
                                      target_label, self.config.epsilon)
            elif method_name == 'PGD':
                adv_result = run_pgd(self.PGD, self.Classifier, input_img,
                                     target_label, self.config.epsilon,
                                     self.config.test_iter_steps)
            elif method_name == 'CW':
                adv_result = run_cw(self.CW, self.Classifier, input_img,
                                    target_label)
            elif method_name == 'ORIGINAL':
                adv_result = input_img
            else:
                raise ValueError(
                    '[!] Unknown attack method: {}'.format(method_name))

            output_list.append(adv_result)
            target_list.append(target_label)

        output_tensor = torch.cat(output_list, dim=0)
        label_tensor = torch.cat(target_list, dim=0)

        if self.config.test_save_adv:
            np.save(
                '{}/attack_{}_step{}_img.npy'.format(test_dir, method_name,
                                                     iter_step),
                output_tensor.permute(0, 2, 3, 1).cpu().numpy())
            # labels are 1-D, so no permute is needed here
            np.save(
                '{}/attack_{}_step{}_label.npy'.format(test_dir, method_name,
                                                       iter_step),
                label_tensor.cpu().numpy())

        self._test_classifier(output_tensor, label_tensor, iter_step,
                              method_name)

    def defence_regular_eval(self, iter_step=0):
        # set classifier to be in evaluation mode
        self.Classifier.eval()

        self._run_single_attack(iter_step, 'FGSM')
        self._run_single_attack(iter_step, 'PGD')
        # self._run_single_attack(iter_step, 'CW')
        self._run_single_attack(iter_step, 'ORIGINAL')

        # switch back to train mode
        self.Classifier.train()
        return

    def defence_over_cnw(self, iter_step=0):
        # assume model is loaded properly
        self.Classifier.eval()

        self._run_single_attack(iter_step, 'CW')

        self.Classifier.train()
        return

    def _get_variable(self, inputs):
        if self.config.num_gpu > 0:
            out = Variable(inputs.cuda())
        else:
            out = Variable(inputs)
        return out
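
# Hedged sketch (added): `run_pgd` used in `_run_single_attack` above is
# imported from elsewhere in the repo and is not shown here. A minimal L-inf
# PGD loop consistent with its call signature (attack, classifier, images,
# labels, epsilon, iter_steps) might look like the following; the step-size
# heuristic is an assumption, not the repo's code.
import torch
import torch.nn.functional as F

def pgd_sketch(classifier, images, labels, epsilon, iter_steps):
    step_size = 2.5 * epsilon / iter_steps  # common heuristic
    adv = images.detach().clone()
    for _ in range(iter_steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(classifier(adv), labels)
        grad, = torch.autograd.grad(loss, adv)
        with torch.no_grad():
            adv = adv + step_size * grad.sign()  # ascend the loss
            adv = images + torch.clamp(adv - images, -epsilon, epsilon)
            adv = adv.clamp(0.0, 1.0)  # stay in the image range
    return adv.detach()
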
def train(args, snapshot_path):
    num_classes = 2
    base_lr = args.base_lr
    train_data_path = args.root_path
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    net = unet_3D(n_classes=num_classes, in_channels=1)
    model = net.cuda()
    DAN = FC3DDiscriminator(num_classes=num_classes)
    DAN = DAN.cuda()

    db_train = BraTS2019(base_dir=train_data_path,
                         split='train',
                         num=None,
                         transform=transforms.Compose([
                             RandomRotFlip(),
                             RandomCrop(args.patch_size),
                             ToTensor(),
                         ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num,
                                250))  # 250 training volumes in this split
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    DAN_optimizer = optim.Adam(DAN.parameters(),
                               lr=args.DAN_lr,
                               betas=(0.9, 0.99))
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(2)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            DAN_target = torch.tensor([1, 1, 0, 0]).cuda()
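            # Note (added): the hard-coded target assumes batch_size == 4,
            # with the first labeled_bs samples labeled (1) and the rest
            # unlabeled (0).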
            model.train()
            DAN.eval()

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs])
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            DAN_outputs = DAN(outputs_soft[args.labeled_bs:],
                              volume_batch[args.labeled_bs:])

            consistency_loss = F.cross_entropy(
                DAN_outputs, (DAN_target[:args.labeled_bs]).long())
            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model.eval()
            DAN.train()
            with torch.no_grad():
                outputs = model(volume_batch)
                outputs_soft = torch.softmax(outputs, dim=1)

            DAN_outputs = DAN(outputs_soft, volume_batch)
            DAN_loss = F.cross_entropy(DAN_outputs, DAN_target.long())
            DAN_optimizer.zero_grad()
            DAN_loss.backward()
            DAN_optimizer.step()

            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[0, 0:1, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image,
                                 iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                avg_metric = test_all_case(model,
                                           args.root_path,
                                           test_list="val.txt",
                                           num_classes=2,
                                           patch_size=args.patch_size,
                                           stride_xy=64,
                                           stride_z=64)
                if avg_metric[:, 0].mean() > best_performance:
                    best_performance = avg_metric[:, 0].mean()
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                writer.add_scalar('info/val_dice_score', avg_metric[0, 0],
                                  iter_num)
                writer.add_scalar('info/val_hd95', avg_metric[0, 1], iter_num)
                logging.info('iteration %d : dice_score : %f hd95 : %f' %
                             (iter_num, avg_metric[0, 0].mean(),
                              avg_metric[0, 1].mean()))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
Exemplo n.º 29
0
class face_learner(object):
    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            self.growup = GrowUP().to(conf.device)
            self.discriminator = Discriminator().to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:

            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)

            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0

            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # no longer used
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)

            print(self.class_num)
            print(conf)

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            if conf.discriminator:
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4,
                                               betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer2 = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                self.optimizer = optim.SGD([{
                    'params': paras_wo_bn,
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                                           lr=conf.lr,
                                           momentum=conf.momentum)
            print('optimizers generated')

            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 2
            self.save_every = len(self.loader)
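            # Cadence note (added): log the loss ~100x per epoch, evaluate
            # twice per epoch, and checkpoint once per epoch.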

            dataset_root = "/home/nas1_userD/yonggyu/Face_dataset/face_emore"
            self.lfw = np.load(
                os.path.join(dataset_root,
                             "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root,
                             "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # no longer used
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong, positive_wrong):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_scalar('{}_negative_wrong'.format(db_name),
                               negative_wrong, self.step)
        self.writer.add_scalar('{}_positive_wrong'.format(db_name),
                               positive_wrong, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(
                        batch.to(conf.device)).cpu() + self.model(
                            fliped.to(conf.device)).cpu()
                    embeddings[idx:idx +
                               conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(
                        batch.to(conf.device)).cpu() + self.model(
                            fliped.to(conf.device)).cpu()
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings1 = np.zeros([len(carray) // 2, conf.embedding_size])
        embeddings2 = np.zeros([len(carray) // 2, conf.embedding_size])

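        # Assumption (added note): verification pairs are stored interleaved,
        # so even indices hold the first image of each pair (routed through
        # growup) and odd indices the second (embedded directly).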
        carray1 = carray[::2, ]
        carray2 = carray[1::2, ]

        with torch.no_grad():
            while idx + conf.batch_size <= len(carray1):
                batch = torch.tensor(carray1[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:idx +
                                conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:idx + conf.batch_size] = self.growup(
                        self.model(batch.to(conf.device))).cpu()
                idx += conf.batch_size
            if idx < len(carray1):
                batch = torch.tensor(carray1[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:] = self.growup(
                        self.model(batch.to(conf.device))).cpu()

            while idx + conf.batch_size <= len(carray2):
                batch = torch.tensor(carray2[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:idx +
                                conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray2):
                batch = torch.tensor(carray2[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:] = self.model(batch.to(conf.device)).cpu()

        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        self.optimizer.zero_grad()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            for imgs, labels, ages in tqdm(iter(self.loader)):

                self.optimizer.zero_grad()

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)

                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()

                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                # wrong-pair counts were added to the evaluations
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # LFW evaluation
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    # NEGATIVE WRONG
                    wrong_list = np.where((self.lfw_issame == False)
                                          & (dist < best_threshold))[0]
                    negative_wrong = len(wrong_list)
                    # POSITIVE WRONG
                    wrong_list = np.where((self.lfw_issame == True)
                                          & (dist > best_threshold))[0]
                    positive_wrong = len(wrong_list)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong,
                                   positive_wrong)

                    # FGNETC evaluation
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    # NEGATIVE WRONG
                    wrong_list = np.where((self.fgnetc_issame == False)
                                          & (dist2 < best_threshold2))[0]
                    negative_wrong2 = len(wrong_list)
                    # POSITIVE WRONG
                    wrong_list = np.where((self.fgnetc_issame == True)
                                          & (dist2 > best_threshold2))[0]
                    positive_wrong2 = len(wrong_list)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2,
                                   positive_wrong2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    # both branches of the original if/else were identical,
                    # so a single call suffices
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        print('Hooray!')

    def train_with_growup(self, conf, epochs):
        '''
        Our method
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with identity labels
                # a_loader, c_loader : adult / child loaders with the same data size
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()

                c = (ages == 0)  # select children for enhancement

                embeddings = self.model(imgs)

                if sum(c) > 1:  # there might be no children in the loader's batch
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                elif sum(c) == 1:
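                    # eval mode (added note): batchnorm inside growup cannot
                    # handle a single-sample batch in train mode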
                    self.growup.eval()
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate passes since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator tries to make the discriminator predict 1
                # (adult) for grown-up child embeddings
                g_loss = conf.ls_loss(pred_c, labels_a)

                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()

                # g_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate_child(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # saves with the most recently computed fgnet_c accuracy (accuracy2)
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(
            conf, accuracy2, to_save_folder=True,
            extra=str(conf.data_mode) + '_' + str(conf.net_depth) + '_' +
            str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        '''
        Our method, without growup
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader: base loader that returns images with identity labels
                # a_loader, c_loader: adult and child loaders of the same dataset size
                # ages: 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate passes since the discriminator uses batchnorm
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator should push the discriminator to output 1 ("adult") for child embeddings
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # saves with the most recently computed fgnet_c accuracy (accuracy2)
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(
            conf, accuracy2, to_save_folder=True,
            extra=str(conf.data_mode) + '_' + str(conf.net_depth) + '_' +
            str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using a paired dataset (TODO)
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader: base loader that returns images with identity labels
                # a_loader, c_loader: adult and child loaders of the same dataset size
                # ages: 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate passes since the discriminator uses batchnorm
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator should push the discriminator to output 1 ("adult") for child embeddings
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # saves with the most recently computed fgnet_c accuracy (accuracy2)
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(
            conf, accuracy2, to_save_folder=True,
            extra=str(conf.data_mode) + '_' + str(conf.net_depth) + '_' +
            str(conf.batch_size) + '_discriminator_final')

    def analyze_angle(self, conf, name):
        '''
        Only works on the age-labeled VGG and AgeDB datasets
        '''

        angle_table = [{
            0: set(),
            1: set(),
            2: set(),
            3: set(),
            4: set(),
            5: set(),
            6: set(),
            7: set()
        } for i in range(self.class_num)]
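        # angle_table[class_id][age_bin] collects angles (degrees) between each sample's embedding and its class weight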
        # batch = 0
        # _angle_table = torch.zeros(self.class_num, 8, len(self.loader)//conf.batch_size).to(conf.device)
        if conf.resume_analysis:
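            # skip the data pass; the angle table is loaded from disk below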
            self.loader = []
        for imgs, labels, ages in tqdm(iter(self.loader)):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)

            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)

            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))

            for i in range(len(thetas)):
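                # age bins: 0: <13, 1: 13-18, 2: 19-25, 3: 26-35, 4: 36-45, 5: 46-55, 6: 56-65, 7: 66+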
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(
                    thetas[i][labels[i]].item())

        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append(
                [len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set)
                if len(single_set) else 0  # if set() size is zero, avg is zero
                for single_set in angle_table[i].values()
            ])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
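        # milestone decay: divide every param group's learning rate by 10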
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def schedule_lr2(self):
        for params in self.optimizer2.param_groups:
            params['lr'] /= 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (horizontal flip only)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = transforms.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self,
                        conf,
                        accuracy,
                        to_save_folder=False,
                        extra=None,
                        model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) +
            ('/lfw_best_model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) +
                ('/lfw_best_head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) +
                ('/lfw_best_optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) +
            ('/model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) +
                ('/head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) +
                ('/optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            if conf.discriminator:
                torch.save(
                    self.growup.state_dict(),
                    str(save_path) +
                    ('/growup_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False,
                   analyze=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(os.path.join(save_path, 'model_{}'.format(fixed_str))))
        if not model_only:
            self.head.load_state_dict(
                torch.load(os.path.join(save_path, 'head_{}'.format(fixed_str))))
            if not analyze:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(save_path, 'optimizer_{}'.format(fixed_str))))
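
A minimal usage sketch for the `infer` path above (hypothetical names throughout: `learner` stands for an instance of the surrounding class, `facebank_embs` and `names` for a precomputed facebank):

# hypothetical sketch, not from the source: match query faces against a facebank
from PIL import Image

faces = [Image.open(p).convert('RGB') for p in ['face1.jpg', 'face2.jpg']]
min_idx, min_dist = learner.infer(conf, faces, facebank_embs, tta=True)
for idx, dist in zip(min_idx.tolist(), min_dist.tolist()):
    print('unknown' if idx == -1 else names[idx], round(dist, 3))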
Exemplo n.º 30
            # Backpropagate the averaged loss
            loss /= p['nAveGrad']
            loss.backward()
            aveGrad += 1

            # Update the weights once every p['nAveGrad'] forward passes
            if aveGrad % p['nAveGrad'] == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            if ii % (num_img_tr // 20) == 0:  # log images ~20 times per epoch
                grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('image', grid_image)
                grid_image = make_grid(utils.decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image)
                grid_image = make_grid(utils.decode_seg_map_sequence(torch.squeeze(gts[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            net.eval()
            for ii, sample_batched in enumerate(testloader):
Exemplo n.º 31
        opt.step()

        if step % attention_period == 0:
            attention = []
            inp = create_sequence(seq_len=10, batch_size=1, cuda=cuda)
            for i in range(inp.size(0)):
                ntm.send(inp[i])
                attention.append(ntm.write_head.attention)
            for i in range(inp.size(0) - 1):
                x = ntm.receive(input_zero)
                attention.append(ntm.read_head.attention)

            attention = torch.stack(attention, dim=0)
            # remove the batch axis and set the sequence on the x axis
            attention = attention.squeeze().transpose(0, 1)
            writer.add_image('Attention', attention, step)

        if nb_samples > 2000000 or (out.size(0) > 15 and meanloss < 5e-4):
            break

    # record the accuracy on different sequence lengths,
    # averaging over 20 sequences; the model has a fixed batch size,
    # so we run several batches in a loop instead of one large batch
    nb_batches = 20 // batch_size
    nb_test_samples = nb_batches * batch_size
    print(f"Number of test batches: {nb_batches} "
          f"and test samples : {nb_test_samples}")
    for seq_len in (list(range(1, 20)) + list(range(20, 101, 10))):
        loss = 0
Exemplo n.º 32
        batch_v = batch_v.to(device)
        gen_output_v = net_gener(gen_input_v)

        # train discriminator
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()
        dis_losses.append(dis_loss.item())

        # train generator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss_v = objective(dis_output_v, true_labels_v)
        gen_loss_v.backward()
        gen_optimizer.step()
        gen_losses.append(gen_loss_v.item())

        iter_no += 1
        if iter_no % REPORT_EVERY_ITER == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
            writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
            writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
            gen_losses = []
            dis_losses = []
        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
            writer.add_image("fake", vutils.make_grid(gen_output_v.data[:64]), iter_no)
            writer.add_image("real", vutils.make_grid(batch_v.data[:64]), iter_no)
Exemplo n.º 33
                writer.add_scalar('val/mse', mse_sum/len(val_set), iteration)
                writer.add_scalar('val/psnr', psnr_sum / len(val_set), iteration)
                writer.add_scalar('val/rgb_error', rgb_loss_sum / len(val_set), iteration)
                writer.add_scalar('val/mean_error', mean_loss_sum / len(val_set), iteration)
                writer.add_scalar('val/perceptual_error', per_loss_sum / len(val_set), iteration)
                writer.add_scalar('val/color_error', col_loss_sum / len(val_set), iteration)

                # save image results
                if epoch % opt.val_img_interval == 0 and epoch != 0:
                    val_images = torch.stack(val_images)
                    val_images = torch.chunk(val_images, val_images.size(0) // (n_val_images * 5))
                    val_save_bar = tqdm(val_images, desc='[Saving results]')
                    for index, image in enumerate(val_save_bar):
                        image = tvutils.make_grid(image, nrow=n_val_images, padding=5)
                        out_path = 'val/target_fake_tex_disc_f-wav_t-wav_' + str(index)
                        writer.add_image('val/target_fake_crop_low_high_' + str(index), image, iteration)

    # save model parameters
    if opt.saving and epoch % opt.save_model_interval == 0 and epoch != 0:
        path = os.path.join(save_path, 'checkpoints', 'iteration_{}.tar'.format(iteration))
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))
        state_dict = {
            'epoch': epoch,
            'iteration': iteration,
            'model_g_state_dict': model_g.state_dict(),
            'models_d_state_dict': model_d.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
            'scheduler_g_state_dict': scheduler_g.state_dict(),
            'scheduler_d_state_dict': scheduler_d.state_dict(),
Exemplo n.º 34
class STGANAgent(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger("STGAN")
        self.logger.info("Creating STGAN architecture...")

        self.G = Generator(len(self.config.attrs), self.config.g_conv_dim, self.config.g_layers, self.config.shortcut_layers, use_stu=self.config.use_stu, one_more_conv=self.config.one_more_conv)
        self.D = Discriminator(self.config.image_size, len(self.config.attrs), self.config.d_conv_dim, self.config.d_fc_dim, self.config.d_layers)

        self.data_loader = globals()['{}_loader'.format(self.config.dataset)](
            self.config.data_root, self.config.mode, self.config.attrs,
            self.config.crop_size, self.config.image_size, self.config.batch_size)

        self.current_iteration = 0
        self.cuda = torch.cuda.is_available() and self.config.cuda

        if self.cuda:
            self.device = torch.device("cuda")
            self.logger.info("Operation will be on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            self.logger.info("Operation will be on *****CPU***** ")

        self.writer = SummaryWriter(log_dir=self.config.summary_dir)

    def save_checkpoint(self):
        G_state = {
            'state_dict': self.G.state_dict(),
            'optimizer': self.optimizer_G.state_dict(),
        }
        D_state  = {
            'state_dict': self.D.state_dict(),
            'optimizer': self.optimizer_D.state_dict(),
        }
        G_filename = 'G_{}.pth.tar'.format(self.current_iteration)
        D_filename = 'D_{}.pth.tar'.format(self.current_iteration)
        torch.save(G_state, os.path.join(self.config.checkpoint_dir, G_filename))
        torch.save(D_state, os.path.join(self.config.checkpoint_dir, D_filename))

    def load_checkpoint(self):
        if self.config.checkpoint is None:
            self.G.to(self.device)
            self.D.to(self.device)
            return
        G_filename = 'G_{}.pth.tar'.format(self.config.checkpoint)
        D_filename = 'D_{}.pth.tar'.format(self.config.checkpoint)
        G_checkpoint = torch.load(os.path.join(self.config.checkpoint_dir, G_filename))
        D_checkpoint = torch.load(os.path.join(self.config.checkpoint_dir, D_filename))
        G_to_load = {k.replace('module.', ''): v for k, v in G_checkpoint['state_dict'].items()}
        D_to_load = {k.replace('module.', ''): v for k, v in D_checkpoint['state_dict'].items()}
        self.current_iteration = self.config.checkpoint
        self.G.load_state_dict(G_to_load)
        self.D.load_state_dict(D_to_load)
        self.G.to(self.device)
        self.D.to(self.device)
        if self.config.mode == 'train':
            self.optimizer_G.load_state_dict(G_checkpoint['optimizer'])
            self.optimizer_D.load_state_dict(D_checkpoint['optimizer'])

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def create_labels(self, c_org, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # get hair color indices
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(len(selected_attrs)):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # set one hair color to 1 and the rest to 0
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # reverse attribute value

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target):
        """Compute binary cross entropy loss."""
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def run(self):
        assert self.config.mode in ['train', 'test']
        try:
            if self.config.mode == 'train':
                self.train()
            else:
                self.test()
        except KeyboardInterrupt:
            self.logger.info('CTRL+C received; waiting to finalize...')
        except Exception:
            with open(os.path.join(self.config.log_dir, 'exp_error.log'), 'w+') as log_file:
                traceback.print_exc(file=log_file)
        finally:
            self.finalize()


    def train(self):
        self.optimizer_G = optim.Adam(self.G.parameters(), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.optimizer_D = optim.Adam(self.D.parameters(), self.config.d_lr, [self.config.beta1, self.config.beta2])
        self.lr_scheduler_G = optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.lr_decay_iters, gamma=0.1)
        self.lr_scheduler_D = optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.lr_decay_iters, gamma=0.1)

        self.load_checkpoint()
        if self.cuda and self.config.ngpu > 1:
            self.G = nn.DataParallel(self.G, device_ids=list(range(self.config.ngpu)))
            self.D = nn.DataParallel(self.D, device_ids=list(range(self.config.ngpu)))

        val_iter = iter(self.data_loader.val_loader)
        x_sample, c_org_sample = next(val_iter)
        x_sample = x_sample.to(self.device)
        c_sample_list = self.create_labels(c_org_sample, self.config.attrs)
        c_sample_list.insert(0, c_org_sample)  # reconstruction

        self.g_lr = self.lr_scheduler_G.get_lr()[0]
        self.d_lr = self.lr_scheduler_D.get_lr()[0]

        data_iter = iter(self.data_loader.train_loader)
        start_time = time.time()
        for i in range(self.current_iteration, self.config.max_iters):
            self.G.train()
            self.D.train()
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # fetch real images and labels
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:  # restart the loader when it is exhausted
                data_iter = iter(self.data_loader.train_loader)
                x_real, label_org = next(data_iter)

            # generate target domain labels randomly
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)         # input images
            c_org = c_org.to(self.device)           # original domain labels
            c_trg = c_trg.to(self.device)           # target domain labels
            label_org = label_org.to(self.device)   # labels for computing classification loss
            label_trg = label_trg.to(self.device)   # labels for computing classification loss

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # compute loss with real images
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org)

            # compute loss with fake images
            attr_diff = c_trg - c_org
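            # jitter the attribute difference with a random intensity in (0, 2 * thres_int)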
            attr_diff = attr_diff * torch.rand_like(attr_diff) * (2 * self.config.thres_int)
            x_fake = self.G(x_real, attr_diff)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # compute loss for gradient penalty
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # backward and optimize
            d_loss_adv = d_loss_real + d_loss_fake + self.config.lambda_gp * d_loss_gp
            d_loss = d_loss_adv + self.config.lambda1 * d_loss_cls
            self.optimizer_D.zero_grad()
            d_loss.backward(retain_graph=True)
            self.optimizer_D.step()

            # summarize
            scalars = {}
            scalars['D/loss'] = d_loss.item()
            scalars['D/loss_adv'] = d_loss_adv.item()
            scalars['D/loss_cls'] = d_loss_cls.item()
            scalars['D/loss_real'] = d_loss_real.item()
            scalars['D/loss_fake'] = d_loss_fake.item()
            scalars['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            if (i + 1) % self.config.n_critic == 0:
                # original-to-target domain
                x_fake = self.G(x_real, attr_diff)
                out_src, out_cls = self.D(x_fake)
                g_loss_adv = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg)

                # target-to-original domain
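                # c_org - c_org is an all-zero attribute difference, i.e. plain reconstruction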
                x_reconst = self.G(x_fake, c_org - c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # backward and optimize
                g_loss = g_loss_adv + self.config.lambda3 * g_loss_rec + self.config.lambda2 * g_loss_cls
                self.optimizer_G.zero_grad()
                g_loss.backward()
                self.optimizer_G.step()

                # summarize
                scalars['G/loss'] = g_loss.item()
                scalars['G/loss_adv'] = g_loss_adv.item()
                scalars['G/loss_cls'] = g_loss_cls.item()
                scalars['G/loss_rec'] = g_loss_rec.item()

            self.current_iteration += 1

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            if self.current_iteration % self.config.summary_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                print('Elapsed [{}], Iteration [{}/{}]'.format(et, self.current_iteration, self.config.max_iters))
                for tag, value in scalars.items():
                    self.writer.add_scalar(tag, value, self.current_iteration)

            if self.current_iteration % self.config.sample_step == 0:
                self.G.eval()
                with torch.no_grad():
                    x_sample = x_sample.to(self.device)
                    x_fake_list = [x_sample]
                    for c_trg_sample in c_sample_list:
                        attr_diff = c_trg_sample.to(self.device) - c_org_sample.to(self.device)
                        attr_diff = attr_diff * self.config.thres_int
                        x_fake_list.append(self.G(x_sample, attr_diff.to(self.device)))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    self.writer.add_image('sample', make_grid(self.denorm(x_concat.data.cpu()), nrow=1),
                                          self.current_iteration)
                    save_image(self.denorm(x_concat.data.cpu()),
                               os.path.join(self.config.sample_dir, 'sample_{}.jpg'.format(self.current_iteration)),
                               nrow=1, padding=0)

            if self.current_iteration % self.config.checkpoint_step == 0:
                self.save_checkpoint()

            self.lr_scheduler_G.step()
            self.lr_scheduler_D.step()

    def test(self):
        self.load_checkpoint()
        self.G.to(self.device)

        tqdm_loader = tqdm(self.data_loader.test_loader, total=self.data_loader.test_iterations,
                          desc='Testing at checkpoint {}'.format(self.config.checkpoint))

        self.G.eval()
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(tqdm_loader):
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.config.attrs)

                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    attr_diff = c_trg - c_org
                    x_fake_list.append(self.G(x_real, attr_diff.to(self.device)))
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.config.result_dir, 'sample_{}.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)

    def finalize(self):
        print('Please wait while finalizing the operation... Thank you')
        self.writer.export_scalars_to_json(os.path.join(self.config.summary_dir, 'all_scalars.json'))
        self.writer.close()
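
The `gradient_penalty` above is the WGAN-GP penalty; a self-contained sketch of the same computation with a toy critic (hypothetical, for illustration only):

# hypothetical sketch, not from the source: WGAN-GP on random tensors
import torch
import torch.nn as nn

def wgan_gp(critic, real, fake):
    # interpolate between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    # gradient of the critic output w.r.t. the interpolates
    grad = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    # penalize deviation of the per-sample gradient L2 norm from 1
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
print(wgan_gp(critic, real, fake).item())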
Exemplo n.º 35
Arquivo: utils.py Projeto: hyun78/elsa
class Logger(object):
    """Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514"""
    def __init__(self, fn, ask=True, local_rank=0, gpu_num=0):
        self.local_rank = local_rank
        self.gpu_num = gpu_num
        if self.local_rank == 0:
            #if not os.path.exists("./logs/"):
            #    os.mkdir("./logs/")
            if not os.path.exists("./logs" + str(gpu_num) + "/"):
                os.mkdir("./logs" + str(gpu_num) + "/")

            logdir = self._make_dir(fn)
            if not os.path.exists(logdir):
                os.mkdir(logdir)

            if len(os.listdir(logdir)) != 0 and ask:
                ans = input(
                    "log_dir is not empty. All data inside log_dir will be deleted. "
                    "Will you proceed [y/N]? ")
                if ans in ['y', 'Y']:
                    shutil.rmtree(logdir)
                else:
                    exit(1)

            self.set_dir(logdir)

    def _make_dir(self, fn):
        today = datetime.today().strftime("%y%m%d")
        logdir = 'logs' + str(self.gpu_num) + '/' + fn
        return logdir

    def set_dir(self, logdir, log_fn='log.txt'):
        self.logdir = logdir
        if not os.path.exists(logdir):
            os.mkdir(logdir)
        self.writer = SummaryWriter(logdir)
        self.log_file = open(os.path.join(logdir, log_fn), 'a')

    def log(self, string):
        if self.local_rank == 0:
            self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n')
            self.log_file.flush()

            print('[%s] %s' % (datetime.now(), string))
            sys.stdout.flush()

    def log_dirname(self, string):
        if self.local_rank == 0:
            self.log_file.write('%s (%s)' % (string, self.logdir) + '\n')
            self.log_file.flush()

            print('%s (%s)' % (string, self.logdir))
            sys.stdout.flush()

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        if self.local_rank == 0:
            self.writer.add_scalar(tag, value, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        if self.local_rank == 0:
            self.writer.add_image(tag, images, step)

    def histo_summary(self, tag, values, step):
        """Log a histogram of the tensor of values."""
        if self.local_rank == 0:
            self.writer.add_histogram(tag, values, step, bins='auto')
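
A minimal usage sketch for this Logger (hypothetical experiment name; single process, so local_rank stays 0):

# hypothetical sketch, not from the source
logger = Logger('my_experiment', ask=False, local_rank=0, gpu_num=0)
logger.log('training started')
logger.scalar_summary('train/loss', 0.42, 100)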
Exemplo n.º 36
    args = parser.parse_args()

    if args.resume is None:
        args.output = get_output_folder(args.output, args.env)
    else:
        args.output = args.resume

    bullet = ("Bullet" in args.env)
    if bullet:
        import pybullet
        import pybullet_envs

    if args.env == "Paint":
        from env import CanvasEnv
        env = CanvasEnv()
        writer.add_image('circle.png', cv2.cvtColor(env.target, cv2.COLOR_GRAY2RGB))
    elif args.env == "KukaGym":
        env = KukaGymEnv(renders=False, isDiscrete=True)
    elif args.env == "LTR":
        from osim.env import RunEnv
        env = RunEnv(visualize=False)
    elif args.discrete:        
        env = gym.make(args.env)
        env = env.unwrapped
    else:
        env = NormalizedEnv(gym.make(args.env))

    # input random seed
    if args.seed > 0:
        np.random.seed(args.seed)
        env.seed(args.seed)
Exemplo n.º 37
    tb = SummaryWriter('runs', 'mini-imagenet')
    mini = MiniImagenet('../mini-imagenet/',
                        mode='train',
                        n_way=5,
                        k_shot=1,
                        k_query=1,
                        batchsz=1000,
                        resize=168)

    for i, set_ in enumerate(mini):
        # support_x: [k_shot*n_way, 3, 84, 84]
        support_x, support_y, query_x, query_y = set_

        support_x = make_grid(support_x, nrow=2)
        query_x = make_grid(query_x, nrow=2)

        plt.figure(1)
        plt.imshow(support_x.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
        plt.pause(0.5)
        plt.figure(2)
        plt.imshow(query_x.permute(1, 2, 0).numpy())
        plt.pause(0.5)

        tb.add_image('support_x', support_x)
        tb.add_image('query_x', query_x)

        time.sleep(5)

    tb.close()
Exemplo n.º 38
            if i > 10:
                break
            opt_dec.zero_grad()
            opt_en.zero_grad()
            img = batch['image'].to(device)
            mask = batch['mask'].to(device)
            pred, loss = model(img, mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 100)
            opt_dec.step()
            opt_en.step()
            writer.add_scalar('loss', float(loss), global_step=iterate)
            if iterate % args.display_freq == 0:
                for masked in pred:
                    writer.add_image('{}'.format(masked.size()[2]),
                                     masked,
                                     global_step=iterate)
                writer.add_image('GT', mask, iterate)
                writer.add_image('Image', img, iterate)

            if iterate % 200 == 0:
                if i != 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(weight_save_dir,
                                     '{}epo_{}step.ckpt'.format(epo, iterate)))
            if iterate % 1000 == 0 and i != 0:
                for file in os.listdir(weight_save_dir):  # prune intermediate checkpoints
                    if '00' in file and '000' not in file:
                        os.remove(os.path.join(weight_save_dir, file))
            if (i + epo * len(dataloader)) % decay_step == 0 and i != 0:
Exemplo n.º 39
class Trainer(Solver):
    ''' Handler for complete training progress'''
    def __init__(self, config, paras):
        super(Trainer, self).__init__(config, paras)
        # Logger Settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(self.logdir)

        # Training details
        self.apex = config['solver']['apex']
        self.log_step = config['solver']['log_step']
        self.save_step = config['solver']['save_step']
        self.total_steps = config['solver']['total_steps']
        self.learning_rate = float(self.config['optimizer']['learning_rate'])
        self.warmup_proportion = self.config['optimizer']['warmup_proportion']
        self.gradient_accumulation_steps = self.config['optimizer']['gradient_accumulation_steps']
        self.gradient_clipping = self.config['optimizer']['gradient_clipping']
        self.max_keep = config['solver']['max_keep']
        self.reset_train()

        # mkdir
        if not os.path.exists(self.paras.ckpdir): os.makedirs(self.paras.ckpdir)
        if not os.path.exists(self.ckpdir): os.makedirs(self.ckpdir)
        copyfile(self.paras.config, os.path.join(self.ckpdir, self.paras.config.split('/')[-1]))

    def reset_train(self):
        self.model_kept = []
        self.global_step = 1


    def process_data(self, spec):
        """Process training data for the masked acoustic model"""
        with torch.no_grad():
            
            assert len(spec) == 5, 'dataloader should return (spec_masked, pos_enc, mask_label, attn_mask, spec_stacked)'
            # Unpack and undo bucketing: bucketing gives acoustic features the shape (1, B, T, D)
            spec_masked = spec[0].squeeze(0)
            pos_enc = spec[1].squeeze(0)
            mask_label = spec[2].squeeze(0)
            attn_mask = spec[3].squeeze(0)
            spec_stacked = spec[4].squeeze(0)

            spec_masked = spec_masked.to(device=self.device)
            pos_enc = torch.FloatTensor(pos_enc).to(device=self.device)
            mask_label = torch.ByteTensor(mask_label).to(device=self.device)
            attn_mask = torch.FloatTensor(attn_mask).to(device=self.device)
            spec_stacked = spec_stacked.to(device=self.device)

        return spec_masked, pos_enc, mask_label, attn_mask, spec_stacked # (x, pos_enc, mask_label, attention_mask. y)


    def exec(self):
        ''' Training Unsupervised End-to-end Mockingjay Model'''
        self.verbose('Training set total ' + str(len(self.dataloader)) + ' batches.')

        pbar = tqdm(total=self.total_steps)
        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            for step, batch in enumerate(progress):
                try:
                    if self.global_step > self.total_steps: break
                    
                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(batch)
                    loss, pred_spec = self.model(spec_masked, pos_enc, mask_label, attn_mask, spec_stacked)
                    
                    # Accumulate Loss
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex:
                        self.optimizer.backward(loss)
                    else:
                        loss.backward()

                    # Update
                    if step % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with the special warm-up BERT uses;
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        
                        # Step
                        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            self.verbose('Error : grad norm is NaN @ step ' + str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                    if self.global_step % self.log_step == 0:
                        # Log
                        self.log.add_scalar('lr', self.optimizer.get_lr()[0], self.global_step)
                        self.log.add_scalar('loss', loss.item(), self.global_step)
                        self.log.add_scalar('gradient norm', grad_norm, self.global_step)
                        progress.set_description("Loss %.4f" % loss.item())

                    if self.global_step % self.save_step == 0:
                        self.save_model('mockingjay')
                        mask_spec = self.up_sample_frames(spec_masked[0], return_first=True)
                        pred_spec = self.up_sample_frames(pred_spec[0], return_first=True)
                        true_spec = self.up_sample_frames(spec_stacked[0], return_first=True)
                        mask_spec = plot_spectrogram_to_numpy(mask_spec.data.cpu().numpy())
                        pred_spec = plot_spectrogram_to_numpy(pred_spec.data.cpu().numpy())
                        true_spec = plot_spectrogram_to_numpy(true_spec.data.cpu().numpy())
                        self.log.add_image('mask_spec', mask_spec, self.global_step)
                        self.log.add_image('pred_spec', pred_spec, self.global_step)
                        self.log.add_image('true_spec', true_spec, self.global_step)
                
                except RuntimeError:
                    print('CUDA out of memory at step: ', self.global_step)
                    torch.cuda.empty_cache()
                    self.optimizer.zero_grad()

                pbar.update(1)
                self.global_step += 1
                
        pbar.close()
        self.reset_train()
Exemplo n.º 40
            noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            error_g = criterion(fake_output, true_labels)

            print('error_g:', error_g.item())
            writer.add_scalar('data/error_g', error_g.item(), ii)

            error_g.backward()
            optimizer_g.step()

        if (ii + 1) % opt.plot_every == 0:
            fix_fake_imgs = netg(fix_noises)

            fake = fix_fake_imgs[:64] * 0.5 + 0.5
            real = real_img[:64] * 0.5 + 0.5

            writer.add_image('image/fake_Image', fake, ii)
            writer.add_image('image/real_Image', real, ii)

            print('epoch[{}:{}],ii[{}:{}]'.format(epoch, opt.max_epoch, ii, len(dataloader)))

        if (epoch + 1) % opt.decay_every == 0:
            utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                             range=(-1, 1))
            torch.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            torch.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            optimizer_g = torch.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
            optimizer_d = torch.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
Exemplo n.º 41
class face_learner(object):
    def __init__(self, conf, inference=False):
        print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            print('model head generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            #             self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

            print('optimizers generated')
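            # interval schedule: log ~100x, evaluate ~10x, save ~5x per epoch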
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        pretrained_path = save_path / 'model_{}'.format(fixed_str)
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            pretrained_dict = torch.load(
                pretrained_path,
                map_location=lambda storage, loc: storage.cuda(device))
        else:
            pretrained_dict = torch.load(
                pretrained_path, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(pretrained_dict)
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)
#         self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
#         self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
#         self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        flipped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
                print('{} / {}'.format(idx, len(carray)))
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        flipped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
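
    # A hypothetical helper (not in the original) that mirrors the TTA branch
    # of evaluate(): sum the embeddings of a batch and its horizontal flip,
    # then re-project onto the unit hypersphere with l2_norm.
    @staticmethod
    def _tta_embeddings(model, batch, device):
        flipped = hflip_batch(batch)
        emb_sum = model(batch.to(device)) + model(flipped.to(device))
        return l2_norm(emb_sum)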

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                exploding_scale=3.,
                num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Compute the smoothed loss (bias-corrected exponential moving average)
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > exploding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step
            loss.backward()
            self.optimizer.step()

            # Update the lr for the next step
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
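
    # The smoothing in find_lr() is a bias-corrected exponential moving
    # average, the same correction Adam applies to its moment estimates.
    # A standalone sketch of the formula (hypothetical helper, not part of
    # the original class):
    @staticmethod
    def _smoothed_losses(raw_losses, beta=0.98):
        avg, smoothed = 0., []
        for t, loss in enumerate(raw_losses, start=1):
            avg = beta * avg + (1 - beta) * loss
            smoothed.append(avg / (1 - beta ** t))
        return smoothed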

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        accuracy = 0.  # initialised so save_state() is safe before the first evaluation
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones[:3]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)

                self.step += 1

        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test-time augmentation (horizontal flip only)
        returns : (min_idx, minimum) - index of the nearest facebank entry for
                  each face (-1 if the distance exceeds the threshold) and the
                  squared L2 distance itself
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
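
# The broadcasting in infer() computes pairwise squared L2 distances between
# source and facebank embeddings. A minimal standalone sanity check of that
# equivalence (illustrative shapes only: [n_faces, 512] vs [n_bank, 512]):
import torch

source_embs = torch.randn(3, 512)
target_embs = torch.randn(10, 512)
diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
dist = torch.sum(torch.pow(diff, 2), dim=1)  # shape [3, 10]
assert torch.allclose(dist, torch.cdist(source_embs, target_embs) ** 2, atol=1e-4)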
Example No. 42
0
    args = parser.parse_args()

    if args.resume is None:
        args.output = get_output_folder(args.output, args.env)
    else:
        args.output = args.resume

    bullet = ("Bullet" in args.env)
    if bullet:
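        # Imported for their side effect of registering the Bullet environments with gym.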
        import pybullet
        import pybullet_envs

    if args.env == "Paint":
        from env import CanvasEnv
        env = CanvasEnv()
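        # NB: `writer` here is assumed to be a SummaryWriter created earlier in the full script.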
        writer.add_image('circle.png', env.target)
    elif args.env == "KukaGym":
        env = KukaGymEnv(renders=False, isDiscrete=True)
    elif args.env == "LTR":
        from osim.env import RunEnv
        env = RunEnv(visualize=False)
    elif args.discrete:
        env = gym.make(args.env)
        env = env.unwrapped
    else:
        env = NormalizedEnv(gym.make(args.env))

    # Seed the RNGs for reproducibility.
    if args.seed > 0:
        np.random.seed(args.seed)
        env.seed(args.seed)
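
    # `NormalizedEnv` above is not defined in this snippet. A common pattern
    # (an assumption, not necessarily this script's implementation) is a
    # gym.ActionWrapper that rescales agent actions from [-1, 1] to the
    # wrapped environment's action bounds:
    #
    #     class NormalizedEnv(gym.ActionWrapper):
    #         def action(self, action):
    #             low, high = self.action_space.low, self.action_space.high
    #             return low + (action + 1.0) * 0.5 * (high - low)
    #
    #         def reverse_action(self, action):
    #             low, high = self.action_space.low, self.action_space.high
    #             return 2.0 * (action - low) / (high - low) - 1.0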
Example No. 43
0
class ModelTrainerK2K:
    """
    Model trainer for K space to K space learning.
    All loss is calculated on the K space domain only.
    Conversion to image domain is for viewing only.
    """
    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 post_processing, k_loss, metrics=None, scheduler=None):
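        # The 'spawn' start method allows CUDA use inside DataLoader worker
        # processes; note that set_start_method() raises RuntimeError if the
        # start method has already been set elsewhere.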

        multiprocessing.set_start_method(method='spawn')
        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # Requiring a Module keeps post-processing in the same device/autograd pipeline as the model.
        assert isinstance(post_processing, nn.Module), '`post_processing` must be a Pytorch Module.'

        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(k_loss, nn.Module), '`k_loss` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(metrics, Iterable), '`metrics` must be an iterable, preferably a list or tuple.'
            for metric in metrics:
                assert callable(metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.k_loss_func = k_loss
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(self.val_loader.dataset) // (args.max_images * args.batch_size))

        # Writing model graph to TensorBoard. Results might not be very good.
        # if args.add_graph:
        #     num_chans = 30 if args.challenge == 'multicoil' else 2
        #     example_inputs = torch.ones(size=(1, num_chans, 640, 328), device=args.device)
        #     self.writer.add_graph(model=model, input_to_model=example_inputs)
        #     del example_inputs  # Remove unnecessary tensor taking up memory.

        self.checkpointer = CheckpointManager(
            model=self.model, optimizer=self.optimizer, mode='min', save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # Load weights from a previous checkpoint if one is specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

    def train_model(self):
        self.logger.info('Beginning Training Loop.')
        tic_tic = time()

        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch, epoch_loss=train_epoch_loss,
                                    epoch_metrics=train_epoch_metrics, elapsed_secs=toc, training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch, epoch_loss=val_epoch_loss,
                                    epoch_metrics=val_epoch_metrics, elapsed_secs=toc, training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss_lst = list()  # Collect per-step losses; averaging at the end avoids float accumulation error.
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        # labels are fully sampled coil-wise images, not rss or esc.
        for step, (inputs, kspace_targets, extra_params) in enumerate(self.train_loader, start=1):
            step_loss, kspace_recons = self._train_step(inputs, kspace_targets, extra_params)

            # Metrics need no gradients; no_grad() avoids autograd bookkeeping overhead here.
            with torch.no_grad():  # Update epoch loss and metrics.
                epoch_loss_lst.append(step_loss.item())  # Perhaps not elegant, but underflow makes this necessary.

                # The step functions here have all necessary conditionals internally.
                # There is no need to externally specify whether to use them or not.
                step_metrics = self._get_step_metrics(kspace_recons, kspace_targets, epoch_metrics_lst)
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss_lst, epoch_metrics_lst, training=True)
        return epoch_loss, epoch_metrics

    def _train_step(self, inputs, kspace_targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        kspace_recons = self.post_processing_func(outputs, kspace_targets, extra_params)
        step_loss = self.k_loss_func(kspace_recons, kspace_targets)
        step_loss.backward()
        self.optimizer.step()
        return step_loss, kspace_recons

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_lst = list()
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        for step, (inputs, kspace_targets, extra_params) in enumerate(self.val_loader, start=1):
            step_loss, kspace_recons = self._val_step(inputs, kspace_targets, extra_params)

            epoch_loss_lst.append(step_loss.item())
            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(kspace_recons, kspace_targets, epoch_metrics_lst)
            self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                # TODO: visualize_from_kspace() returns five display grids in a fixed order; refactor for readability.
                grids = visualize_from_kspace(kspace_recons, kspace_targets, smoothing_factor=8)
                self.writer.add_image(f'k-space_Recons/{step}', grids[0], epoch, dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}', grids[1], epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', grids[2], epoch, dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}', grids[3], epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', grids[4], epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss_lst, epoch_metrics_lst, training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, kspace_targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """
        outputs = self.model(inputs)
        kspace_recons = self.post_processing_func(outputs, kspace_targets, extra_params)
        step_loss = self.k_loss_func(kspace_recons, kspace_targets)
        return step_loss, kspace_recons

    def _get_step_metrics(self, kspace_recons, kspace_targets, epoch_metrics_lst):
        if self.metrics is not None:
            step_metrics = [metric(kspace_recons, kspace_targets) for metric in self.metrics]
            for step_metric, epoch_metric_lst in zip(step_metrics, epoch_metrics_lst):
                epoch_metric_lst.append(step_metric.item())
            return step_metrics
        return None  # Explicitly return None for step_metrics if self.metrics is None. Not necessary but more readable.

    def _get_epoch_outputs(self, epoch, epoch_loss_lst, epoch_metrics_lst, training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        num_nans = np.isnan(epoch_loss_lst).sum()
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices')

        epoch_loss = float(np.nanmean(epoch_loss_lst))  # Remove nan values just in case.
        epoch_metrics = [float(np.nanmean(epoch_metric_lst)) for epoch_metric_lst in
                         epoch_metrics_lst] if self.metrics else None

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        if self.verbose:
            mode = 'Training' if training else 'Validation'
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
            if self.metrics:
                for idx, step_metric in enumerate(step_metrics):
                    self.logger.info(
                        f'Epoch {epoch:03d} Step {step:03d}: {mode} metric {idx}: {step_metric.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)
        if isinstance(epoch_metrics, list):  # The metrics being returned are either 'None' or a list of values.
            for idx, epoch_metric in enumerate(epoch_metrics, start=1):
                self.logger.info(f'Epoch {epoch:03d} {mode}. Metric {idx}: {epoch_metric}')
                self.writer.add_scalar(f'{mode}_epoch_metric_{idx}', scalar_value=epoch_metric, global_step=epoch)
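
# A minimal wiring sketch for ModelTrainerK2K. The model, loaders, and
# post-processing module below are toy stand-ins for illustration (assumed
# names, not the project's real components):
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

toy_model = nn.Conv2d(2, 2, kernel_size=3, padding=1)
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)

kspace_inputs = torch.randn(8, 2, 64, 64)
kspace_targets = torch.randn(8, 2, 64, 64)
# The third tensor stands in for `extra_params`, unpacked by the step loops.
toy_dataset = TensorDataset(kspace_inputs, kspace_targets, torch.zeros(8))
train_loader = DataLoader(toy_dataset, batch_size=4)
val_loader = DataLoader(toy_dataset, batch_size=4)

class IdentityPostProcessing(nn.Module):
    """Stand-in for the project's post-processing Module."""
    def forward(self, outputs, kspace_targets, extra_params):
        return outputs

# Constructing the trainer needs the full `args` namespace (log_path, verbose,
# num_epochs, ckpt_path, ...), so the call is sketched here only:
# trainer = ModelTrainerK2K(args, toy_model, toy_optimizer, train_loader, val_loader,
#                           post_processing=IdentityPostProcessing(), k_loss=nn.L1Loss())
# trainer.train_model()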