예제 #1
0
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):
        """Build dataset, generator/discriminator, fixed feature nets,
        loss/visualization helpers and LR schedulers from *config*."""
        self.train_config = config["train"]
        self.data_config = config["dataset"]
        self.loss_config = config["loss"]
        net_cfg = config["model"]

        self.outdir = outdir
        self.modeldir = modeldir

        data_cfg = self.data_config
        self.dataset = IllustDataset(data_path, sketch_path, ss_path,
                                     data_cfg["line_method"],
                                     data_cfg["extension"],
                                     data_cfg["train_size"],
                                     data_cfg["valid_size"],
                                     data_cfg["color_space"],
                                     data_cfg["line_space"])
        print(self.dataset)

        gen_cfg = net_cfg["generator"]
        generator = Generator(gen_cfg["in_ch"],
                              base=gen_cfg["base"],
                              num_layers=gen_cfg["num_layers"],
                              up_layers=gen_cfg["up_layers"],
                              guide=gen_cfg["guide"],
                              resnext=gen_cfg["resnext"],
                              encoder_type=gen_cfg["encoder_type"])
        self.gen, self.gen_opt = self._setting_model_optim(generator, gen_cfg)
        self.guide = gen_cfg["guide"]

        dis_cfg = net_cfg["discriminator"]
        discriminator = Discriminator(dis_cfg["in_ch"],
                                      dis_cfg["multi"],
                                      base=dis_cfg["base"],
                                      sn=dis_cfg["sn"],
                                      resnext=dis_cfg["resnext"],
                                      patch=dis_cfg["patch"])
        self.dis, self.dis_opt = self._setting_model_optim(discriminator,
                                                           dis_cfg)

        # Frozen VGG19 used as a fixed feature extractor (no gradients).
        self.vgg = Vgg19(requires_grad=False, layer="four")
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = LossCalculator()
        self.visualizer = Visualizer(data_cfg["color_space"])

        decay = self.train_config["gamma"]
        self.scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(
            self.gen_opt, decay)
        self.scheduler_dis = torch.optim.lr_scheduler.ExponentialLR(
            self.dis_opt, decay)
예제 #2
0
def test_step(network, data_loader, device):
    """Run one evaluation pass over *data_loader*.

    Returns a tuple of (top-1 average, top-5 average, loss log).
    """
    network.eval()

    acc1 = AverageMeter()
    acc5 = AverageMeter()
    loss_calculator = LossCalculator()

    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            predictions = network(batch_inputs)
            # Accumulate the loss into the calculator's internal log.
            loss_calculator.calc_loss(predictions, batch_targets)

            prec1, prec5 = accuracy(predictions, batch_targets, topk=(1, 5))
            batch_size = batch_inputs.size(0)
            acc1.update(prec1.item(), batch_size)
            acc5.update(prec5.item(), batch_size)

    return acc1.avg, acc5.avg, loss_calculator.get_loss_log()
예제 #3
0
def train_step(network, train_data_loader, test_data_loader, optimizer, device,
               epoch):
    """Train *network* for one epoch, evaluate it, and print a summary line.

    The summary reports train/test top-1 and top-5 accuracy, the loss log,
    and elapsed wall-clock time for the training loop. Returns None.
    """
    network.train()
    # set benchmark flag to faster runtime
    torch.backends.cudnn.benchmark = True

    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_calculator = LossCalculator()

    prev_time = datetime.now()

    for iteration, (inputs, targets) in enumerate(train_data_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = network(inputs)
        loss = loss_calculator.calc_loss(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

    cur_time = datetime.now()
    # Fix: timedelta.seconds drops whole days, so epochs running longer than
    # 24h were under-reported; total_seconds() covers the full duration.
    elapsed = int((cur_time - prev_time).total_seconds())
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time: {:0>2d}:{:0>2d}:{:0>2d}".format(h, m, s)

    train_acc_str = '[Train] Top1: %2.4f, Top5: %2.4f, ' % (top1.avg, top5.avg)
    train_loss_str = 'Loss: %.4f. ' % loss_calculator.get_loss_log()

    test_top1, test_top5, test_loss = test_step(network, test_data_loader,
                                                device)

    test_acc_str = '[Test] Top1: %2.4f, Top5: %2.4f, ' % (test_top1, test_top5)
    test_loss_str = 'Loss: %.4f. ' % test_loss

    print('Epoch %d. ' % epoch + train_acc_str + train_loss_str +
          test_acc_str + test_loss_str + time_str)

    return None
예제 #4
0
def load_network(args, device):
    """Construct the StackedHourglass model and its training companions.

    Returns (network, optimizer, scheduler, loss_calculator); the last three
    are None when not in training mode. Optionally restores a checkpoint.
    """
    network = StackedHourglass(num_stack=args.num_stack,
                               in_ch=args.hourglass_inch,
                               out_ch=args.num_cls + 4,
                               increase_ch=args.increase_ch,
                               activation=args.activation,
                               pool=args.pool,
                               neck_activation=args.neck_activation,
                               neck_pool=args.neck_pool).to(device)

    # Distributed wrapper only for multi-GPU training runs.
    if args.train_flag and len(args.gpu_no) > 1:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[device])

    optimizer = scheduler = loss_calculator = None
    if args.train_flag:
        optimizer, scheduler = get_optimizer(network=network,
                                             lr=args.lr,
                                             lr_milestone=args.lr_milestone,
                                             lr_gamma=args.lr_gamma)

        loss_calculator = LossCalculator(hm_weight=args.hm_weight,
                                         offset_weight=args.offset_weight,
                                         size_weight=args.size_weight,
                                         focal_alpha=args.focal_alpha,
                                         focal_beta=args.focal_beta).to(device)

    if args.model_load:
        checkpoint = torch.load(args.model_load, map_location=device)
        network.load_state_dict(checkpoint['state_dict'])
        print('%s: Weights are loaded from %s' %
              (time.ctime(), args.model_load))

        # Training state is only present/needed when resuming a training run.
        if args.train_flag:
            optimizer.load_state_dict(checkpoint['optimizer'])
            loss_calculator.log = checkpoint['loss_log']
            if scheduler is not None:
                scheduler.load_state_dict(checkpoint['scheduler'])

    return network, optimizer, scheduler, loss_calculator
예제 #5
0
    def __init__(self, n_classes=1, n_base_units=32, class_weights=None):
        """Register the conv stack and build the loss calculator."""
        super().__init__()
        self.n_classes = n_classes
        units = n_base_units
        with self.init_scope():
            # Stem conv followed by depthwise-separable blocks; the third
            # ConvDW argument is presumably the stride — TODO confirm.
            self.conv_bn = ConvBN(3, units, 2)
            self.conv_ds_2 = ConvDW(units, units * 2, 1)
            self.conv_ds_3 = ConvDW(units * 2, units * 4, 2)
            self.conv_ds_4 = ConvDW(units * 4, units * 4, 1)
            self.conv_ds_5 = ConvDW(units * 4, units * 8, 2)
            self.conv_ds_6 = ConvDW(units * 8, units * 8, 1)
            self.conv_ds_7 = ConvDW(units * 8, units * 16, 2)

            self.conv_ds_8 = ConvDW(units * 16, units * 16, 1)
            self.conv_ds_9 = ConvDW(units * 16, units * 16, 1)
            self.conv_ds_10 = ConvDW(units * 16, units * 16, 1)

            self.conv_ds_11 = ConvDW(units * 16, units * 32, 2)
            # Head projects to 4 box parameters plus per-class scores.
            self.conv_ds_12 = ConvDW(units * 32, 4 + n_classes, 1)

        self.loss_calc = LossCalculator(n_classes,
                                        weight_noobj=0.2,
                                        class_weights=class_weights)
예제 #6
0
class Trainer:
    def __init__(self, config, outdir, modeldir, data_path, sketch_path,
                 ss_path):

        self.train_config = config["train"]
        self.data_config = config["dataset"]
        model_config = config["model"]
        self.loss_config = config["loss"]

        self.outdir = outdir
        self.modeldir = modeldir

        self.dataset = IllustDataset(
            data_path, sketch_path, ss_path, self.data_config["line_method"],
            self.data_config["extension"], self.data_config["train_size"],
            self.data_config["valid_size"], self.data_config["color_space"],
            self.data_config["line_space"])
        print(self.dataset)

        gen = Generator(model_config["generator"]["in_ch"],
                        base=model_config["generator"]["base"],
                        num_layers=model_config["generator"]["num_layers"],
                        up_layers=model_config["generator"]["up_layers"],
                        guide=model_config["generator"]["guide"],
                        resnext=model_config["generator"]["resnext"],
                        encoder_type=model_config["generator"]["encoder_type"])
        self.gen, self.gen_opt = self._setting_model_optim(
            gen, model_config["generator"])
        self.guide = model_config["generator"]["guide"]

        dis = Discriminator(model_config["discriminator"]["in_ch"],
                            model_config["discriminator"]["multi"],
                            base=model_config["discriminator"]["base"],
                            sn=model_config["discriminator"]["sn"],
                            resnext=model_config["discriminator"]["resnext"],
                            patch=model_config["discriminator"]["patch"])
        self.dis, self.dis_opt = self._setting_model_optim(
            dis, model_config["discriminator"])

        self.vgg = Vgg19(requires_grad=False, layer="four")
        self.vgg.cuda()
        self.vgg.eval()

        self.out_filter = GuidedFilter(r=1, eps=1e-2)
        self.out_filter.cuda()

        self.lossfunc = LossCalculator()
        self.visualizer = Visualizer(self.data_config["color_space"])

        self.scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(
            self.gen_opt, self.train_config["gamma"])
        self.scheduler_dis = torch.optim.lr_scheduler.ExponentialLR(
            self.dis_opt, self.train_config["gamma"])

    @staticmethod
    def _setting_model_optim(model: nn.Module, config: Dict):
        model.cuda()
        if config["mode"] == "train":
            model.train()
        elif config["mode"] == "eval":
            model.eval()

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config["lr"],
                                     betas=(config["b1"], config["b2"]))

        return model, optimizer

    @staticmethod
    def _valid_prepare(dataset, validsize: int) -> List[torch.Tensor]:
        c_val, l_i_val, m_val, l_m_val = dataset.valid(validsize)
        x_val = torch.cat([l_i_val, m_val], dim=1)

        return [x_val, l_i_val, m_val, c_val, l_m_val]

    @staticmethod
    def _build_dict(loss_dict: Dict[str, float], epoch: int,
                    num_epochs: int) -> Dict[str, str]:

        report_dict = {}
        report_dict["epoch"] = f"{epoch}/{num_epochs}"
        for k, v in loss_dict.items():
            report_dict[k] = f"{v:.4f}"

        return report_dict

    def _loss_weight_scheduler(self, iteration: int):
        if iteration > 50000:
            self.loss_config["adv"] = 0.1
            self.loss_config["content"] = 1.0
            self.loss_config["pef"] = 0.0
        else:
            self.loss_config["adv"] = 0.1
            self.loss_config["content"] = 100.0
            self.loss_config["pef"] = 0.01

    def _eval(self, iteration: int, validsize: int,
              v_list: List[torch.Tensor]):

        torch.save(self.gen.state_dict(),
                   f"{self.modeldir}/generator_{iteration}.pt")
        torch.save(self.dis.state_dict(),
                   f"{self.modeldir}/discriminator_{iteration}.pt")

        with torch.no_grad():
            mid = copy.copy(v_list[4])
            if self.guide:
                y, _, _ = self.gen(v_list[0], mid)
            else:
                y = self.gen(v_list[0], mid)

            y = self.out_filter(v_list[1], y)

        self.visualizer(v_list[1:], y, self.outdir, iteration, validsize)

    def _iter(self, data):
        color, line, mask, line_m = data
        color = color.cuda()
        line = line.cuda()
        mask = mask.cuda()
        line_m = line_m.cuda()

        loss = {}

        x = torch.cat([line, mask], dim=1)
        mid = line_m

        if self.guide:
            y, g1, g2 = self.gen(x, mid)
        else:
            y = self.gen(x, mid)

        y = self.out_filter(line, y)

        if self.loss_config["adv"] > 0:
            # discriminator update
            dis_loss = self.loss_config[
                "adv"] * self.lossfunc.adversarial_disloss(
                    self.dis, y.detach(), color)

            self.dis_opt.zero_grad()
            dis_loss.backward()
            self.dis_opt.step()
        else:
            dis_loss = torch.zeros(1).cuda()

        if self.loss_config["gp"] > 0.0:
            color.requires_grad = True
            gp_loss = self.loss_config["gp"] * self.lossfunc.gradient_penalty(
                self.dis, color)

            self.dis_opt.zero_grad()
            gp_loss.backward()
            self.dis_opt.step()

            color.requires_grad = False
        else:
            gp_loss = torch.zeros(1).cuda()

        if self.guide:
            y, g1, g2 = self.gen(x, mid)
        else:
            y = self.gen(x, mid)

        y = self.out_filter(line, y)

        # generator update
        if self.loss_config["adv"] > 0:
            adv_loss, fm_loss = self.lossfunc.adversarial_genloss(
                self.dis, y, color)
            adv_gen_loss = self.loss_config["adv"] * adv_loss
            fm_loss = self.loss_config["fm"] * fm_loss
        else:
            adv_gen_loss = torch.zeros(1).cuda()
            fm_loss = torch.zeros(1).cuda()

        tv_loss = self.loss_config["tv"] * self.lossfunc.total_variation_loss(
            y)
        content_loss = self.loss_config[
            "content"] * self.lossfunc.content_loss(y, color)
        pef_loss = self.loss_config[
            "pef"] * self.lossfunc.positive_enforcing_loss(y)
        perceptual_loss = self.loss_config[
            "perceptual"] * self.lossfunc.perceptual_loss(self.vgg, y, color)

        if self.guide:
            content_loss += self.loss_config[
                "content"] * self.lossfunc.content_loss(g1, color)
            content_loss += self.loss_config[
                "content"] * self.lossfunc.content_loss(g2, color)

        gen_loss = adv_gen_loss + fm_loss + tv_loss + content_loss + pef_loss + perceptual_loss

        self.gen_opt.zero_grad()
        gen_loss.backward()
        self.gen_opt.step()

        loss["loss_adv_dis"] = dis_loss.item()
        loss["loss_adv_gen"] = adv_gen_loss.item()
        loss["loss_fm"] = fm_loss.item()
        loss["loss_tv"] = tv_loss.item()
        loss["loss_content"] = content_loss.item()
        loss["loss_pef"] = pef_loss.item()
        loss["loss_perceptual"] = perceptual_loss.item()
        loss["loss_gp"] = gp_loss.item()

        return loss

    def __call__(self):
        iteration = 0
        v_list = self._valid_prepare(self.dataset,
                                     self.train_config["validsize"])

        for epoch in range(self.train_config["epoch"]):
            dataloader = DataLoader(self.dataset,
                                    batch_size=self.train_config["batchsize"],
                                    shuffle=True,
                                    drop_last=True)

            with tqdm(total=len(self.dataset)) as pbar:
                for index, data in enumerate(dataloader):
                    self._loss_weight_scheduler(iteration)
                    iteration += 1
                    loss_dict = self._iter(data)

                    report_dict = self._build_dict(loss_dict, epoch,
                                                   self.train_config["epoch"])

                    pbar.update(self.train_config["batchsize"])
                    pbar.set_postfix(**report_dict)

                    if iteration % self.train_config["snapshot_interval"] == 1:
                        self._eval(iteration, self.train_config["validsize"],
                                   v_list)

            self.scheduler_dis.step()
            self.scheduler_gen.step()
예제 #7
0
class MobileYOLO(chainer.Chain):
    """MobileNet-style backbone with a YOLO-like single-shot detection head.

    predict() maps a batch of images to a (batch, n_grid**2, 4 + n_classes)
    tensor: per grid cell, 4 box values plus class scores (interpretation of
    the 4 values is defined by LossCalculator). __call__ computes and reports
    the training loss.
    """

    # Expected input resolution and output grid side length.
    img_size = 224
    n_grid = 7

    def __init__(self, n_classes=1, n_base_units=32, class_weights=None):
        """Register the conv stack and build the loss calculator.

        n_classes: number of object classes in the head output.
        n_base_units: channel width multiplier for the whole stack.
        class_weights: forwarded to LossCalculator (per-class weighting).
        """
        super().__init__()
        self.n_classes = n_classes
        with self.init_scope():
            self.conv_bn = ConvBN(3, n_base_units, 2)
            self.conv_ds_2 = ConvDW(n_base_units, n_base_units * 2, 1)
            self.conv_ds_3 = ConvDW(n_base_units * 2, n_base_units * 4, 2)
            self.conv_ds_4 = ConvDW(n_base_units * 4, n_base_units * 4, 1)
            self.conv_ds_5 = ConvDW(n_base_units * 4, n_base_units * 8, 2)
            self.conv_ds_6 = ConvDW(n_base_units * 8, n_base_units * 8, 1)
            self.conv_ds_7 = ConvDW(n_base_units * 8, n_base_units * 16, 2)

            self.conv_ds_8 = ConvDW(n_base_units * 16, n_base_units * 16, 1)
            self.conv_ds_9 = ConvDW(n_base_units * 16, n_base_units * 16, 1)
            self.conv_ds_10 = ConvDW(n_base_units * 16, n_base_units * 16, 1)
            self.conv_ds_11 = ConvDW(n_base_units * 16, n_base_units * 16, 1)
            self.conv_ds_12 = ConvDW(n_base_units * 16, n_base_units * 16, 1)

            self.conv_ds_13 = ConvDW(n_base_units * 16, n_base_units * 32, 2)
            # Head: project to 4 box values + class scores per cell.
            self.conv_ds_14 = ConvDW(n_base_units * 32, 4 + n_classes, 1)

        self.loss_calc = LossCalculator(n_classes,
                                        weight_noobj=0.2,
                                        class_weights=class_weights)

    def __call__(self, x, t):
        """Compute the loss dict for batch (x, t), report it via
        chainer.report, and return the scalar 'loss' entry."""
        pred = self.predict(x)
        evaluated = self.loss_calc.loss(pred, t)
        chainer.report(evaluated, self)
        return evaluated['loss']

    def predict(self, x):
        """Forward pass: returns (batch, n_grid**2, 4 + n_classes)."""
        h = self.conv_bn(x)
        h = self.conv_ds_2(h)
        h = self.conv_ds_3(h)
        h = self.conv_ds_4(h)
        h = self.conv_ds_5(h)
        h = self.conv_ds_6(h)
        h = self.conv_ds_7(h)
        h = self.conv_ds_8(h)
        h = self.conv_ds_9(h)
        h = self.conv_ds_10(h)
        h = self.conv_ds_11(h)
        h = self.conv_ds_12(h)
        h = self.conv_ds_13(h)
        h = self.conv_ds_14(h)

        # (batch_size, 4 + n_classes, 7, 7) -> (batch_size, 7, 7, 4 + n_classes)
        h = F.transpose(h, (0, 2, 3, 1))
        # (batch_size, 7, 7, 4 + n_classes) -> (batch_size, 49, 4 + n_classes)
        # Batch size is recovered from the element count so the model also
        # works when the incoming batch dimension is not statically known.
        batch_size = int(h.size / (self.n_grid**2 * (4 + self.n_classes)))
        r = F.reshape(h, (batch_size, self.n_grid**2, 4 + self.n_classes))
        return r

    def to_gpu(self, *args, **kwargs):
        """Move the link and its loss calculator to GPU."""
        self.loss_calc.to_gpu()
        return super().to_gpu(*args, **kwargs)

    def to_cpu(self, *args, **kwargs):
        """Move the link and its loss calculator to CPU."""
        self.loss_calc.to_cpu()
        # Bug fix: this previously delegated to super().to_gpu(), which moved
        # the link back to GPU when the caller asked for CPU.
        return super().to_cpu(*args, **kwargs)