Example #1
def infer(valid_queue, model, epoch, Latency, criterion, writer):
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')

    # set chosen op active
    model.module.set_chosen_op_active()
    model.module.unused_modules_off()

    model.eval()

    progress = utils.ProgressMeter(len(valid_queue), batch_time, losses, top1, top5,
                                   prefix='Test: ')
    cur_step = epoch * len(valid_queue)

    # FLOPs and latency depend only on the chosen architecture, so compute
    # them once instead of once per batch
    shape = [1, 3, 224, 224]
    input_var = torch.zeros(shape, device=device)
    flops = model.module.get_flops(input_var)
    if config.target_hardware in [None, 'flops']:
        latency = 0
    else:
        latency = Latency.predict_latency(model)

    end = time.time()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            # Variable/volatile is deprecated; torch.no_grad() already
            # disables autograd here
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            logits = model(input)
            loss = criterion(logits, target)
            acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            reduced_loss = reduce_tensor(
                loss.data, world_size=config.world_size)
            acc1 = reduce_tensor(acc1, world_size=config.world_size)
            acc5 = reduce_tensor(acc5, world_size=config.world_size)
            losses.update(to_python_float(reduced_loss), n)
            top1.update(to_python_float(acc1), n)
            top5.update(to_python_float(acc5), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if step % config.print_freq == 0:
                progress.print(step)
                logger.info('valid %03d\t loss: %e\t top1: %f\t top5: %f\t flops: %f\t latency: %f', step,
                            losses.avg, top1.avg, top5.avg, flops / 1e6, latency)

    # restore the pruned modules once validation is done
    model.module.unused_modules_back()

    writer.add_scalar('val/loss', losses.avg, cur_step)
    writer.add_scalar('val/top1', top1.avg, cur_step)
    writer.add_scalar('val/top5', top5.avg, cur_step)
    return top1.avg, losses.avg
Example #2
def validate_warmup(valid_queue, model, epoch, criterion, writer):
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    model.train()
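    # note: the model is deliberately kept in train mode for warm-up
    # validation, presumably so BatchNorm uses batch statistics while the
    # running stats of the sampled sub-networks are still unreliable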

    progress = utils.ProgressMeter(len(valid_queue),
                                   batch_time,
                                   losses,
                                   top1,
                                   top5,
                                   prefix='Warmup-Test: ')
    cur_step = epoch * len(valid_queue)

    end = time.time()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            # Variable/volatile is deprecated; torch.no_grad() already
            # disables autograd here
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            logits = model(input)
            loss = criterion(logits, target)
            acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)

            losses.update(loss.item(), n)
            top1.update(acc1.item(), n)
            top5.update(acc5.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if step % config.print_freq == 0:
                progress.print(step)
                logger.info('warmup-valid %03d %e %f %f', step, losses.avg,
                            top1.avg, top5.avg)

    writer.add_scalar('warmup-val/loss', losses.avg, cur_step)
    writer.add_scalar('warmup-val/top1', top1.avg, cur_step)
    writer.add_scalar('warmup-val/top5', top5.avg, cur_step)
    return top1.avg, top5.avg, losses.avg
Example #3
    def train(self, loader, st_step=0, max_step=100000):

        self.gen.train()
        if self.disc is not None:
            self.disc.train()

        # loss stats
        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "indp_exp", "indp_fact", "ac_s", "ac_c",
                                     "cross_ac_s", "cross_ac_c", "ac_gen_s",
                                     "ac_gen_c", "cross_ac_gen_s",
                                     "cross_ac_gen_c")
        # discriminator stats
        discs = utils.AverageMeters("real_font", "real_uni", "fake_font",
                                    "fake_uni", "real_font_acc",
                                    "real_uni_acc", "fake_font_acc",
                                    "fake_uni_acc")
        # etc stats
        stats = utils.AverageMeters("B", "ac_acc_s", "ac_acc_c",
                                    "ac_gen_acc_s", "ac_gen_acc_c")

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")

        for batch in cyclize(loader):
            epoch = self.step // len(loader)
            if self.cfg.use_ddp and (self.step % len(loader)) == 0:
                loader.sampler.set_epoch(epoch)

            style_imgs = batch["style_imgs"].cuda()
            style_fids = batch["style_fids"].cuda()
            style_decs = batch["style_decs"]
            char_imgs = batch["char_imgs"].cuda()
            char_fids = batch["char_fids"].cuda()
            char_decs = batch["char_decs"]

            trg_imgs = batch["trg_imgs"].cuda()
            trg_fids = batch["trg_fids"].cuda()
            trg_cids = batch["trg_cids"].cuda()
            trg_decs = batch["trg_decs"]

            ##############################################################
            # infer
            ##############################################################

            B = len(trg_imgs)
            n_s = style_imgs.shape[1]
            n_c = char_imgs.shape[1]

            style_feats = self.gen.encode(style_imgs.flatten(
                0, 1))  # (B*n_s, n_exp, *feat_shape)
            char_feats = self.gen.encode(char_imgs.flatten(0, 1))

            self.add_indp_exp_loss(
                torch.cat([style_feats["last"], char_feats["last"]]))

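            # factorize splits each encoded feature into a style factor
            # (index 0) and a content factor (index 1)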
            style_facts_s = self.gen.factorize(
                style_feats, 0)  # (B*n_s, n_exp, *feat_shape)
            style_facts_c = self.gen.factorize(style_feats, 1)
            char_facts_s = self.gen.factorize(char_feats, 0)
            char_facts_c = self.gen.factorize(char_feats, 1)

            self.add_indp_fact_loss(
                [style_facts_s["last"], style_facts_c["last"]],
                [style_facts_s["skip"], style_facts_c["skip"]],
                [char_facts_s["last"], char_facts_c["last"]],
                [char_facts_s["skip"], char_facts_c["skip"]],
            )

            mean_style_facts = {
                k: utils.add_dim_and_reshape(v, 0, (-1, n_s)).mean(1)
                for k, v in style_facts_s.items()
            }
            mean_char_facts = {
                k: utils.add_dim_and_reshape(v, 0, (-1, n_c)).mean(1)
                for k, v in char_facts_c.items()
            }
            gen_feats = self.gen.defactorize(
                [mean_style_facts, mean_char_facts])
            gen_imgs = self.gen.decode(gen_feats)

            stats.updates({
                "B": B,
            })

            real_font, real_uni, *real_feats = self.disc(
                trg_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])

            fake_font, fake_uni = self.disc(gen_imgs.detach(), trg_fids,
                                            trg_cids)
            self.add_gan_d_loss([real_font, real_uni], [fake_font, fake_uni])

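            # discriminator step: gen_imgs was detached above, so only D
            # receives gradients here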
            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            fake_font, fake_uni, *fake_feats = self.disc(
                gen_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])
            self.add_gan_g_loss(fake_font, fake_uni)

            self.add_fm_loss(real_feats, fake_feats)

            def racc(x):
                return (x > 0.).float().mean().item()

            def facc(x):
                return (x < 0.).float().mean().item()

            discs.updates(
                {
                    "real_font": real_font.mean().item(),
                    "real_uni": real_uni.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_uni": fake_uni.mean().item(),
                    'real_font_acc': racc(real_font),
                    'real_uni_acc': racc(real_uni),
                    'fake_font_acc': facc(fake_font),
                    'fake_uni_acc': facc(fake_uni)
                }, B)

            self.add_pixel_loss(gen_imgs, trg_imgs)

            self.g_optim.zero_grad()

            self.add_ac_losses_and_update_stats(
                torch.cat([style_facts_s["last"], char_facts_s["last"]]),
                torch.cat([style_fids.flatten(),
                           char_fids.flatten()]),
                torch.cat([style_facts_c["last"], char_facts_c["last"]]),
                style_decs + char_decs, gen_imgs, trg_fids, trg_decs, stats)
            self.ac_optim.zero_grad()
            self.ac_backward()
            self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)  # accum loss stats

            # EMA g
            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(style_imgs, char_imgs)

            torch.cuda.synchronize()

            if self.cfg.gpu <= 0:
                if self.step % self.cfg.tb_freq == 0:
                    self.plot(losses, discs, stats)

                if self.step % self.cfg.print_freq == 0:
                    self.log(losses, discs, stats)
                    self.logger.debug(
                        "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                        torch.cuda.max_memory_allocated() / 1000 / 1000,
                        torch.cuda.max_memory_reserved() / 1000 / 1000)
                    losses.resets()
                    discs.resets()
                    stats.resets()

                    nrow = len(trg_imgs)
                    grid = utils.make_comparable_grid(trg_imgs.detach().cpu(),
                                                      gen_imgs.detach().cpu(),
                                                      nrow=nrow)
                    self.writer.add_image("last", grid, self.step)

                if self.step > 0 and self.step % self.cfg.val_freq == 0:
                    epoch = self.step / len(loader)
                    self.logger.info(
                        "Validation at Epoch = {:.3f}".format(epoch))

                    if not self.is_bn_gen:
                        self.sync_g_ema(style_imgs, char_imgs)

                    self.evaluator.comparable_val_saveimg(
                        self.gen_ema,
                        self.test_loader,
                        self.step,
                        n_row=self.test_n_row)

                    self.save(loss_dic['g_total'], self.cfg.save,
                              self.cfg.get('save_freq', self.cfg.val_freq))

            if self.step >= max_step:
                break

            self.step += 1

        self.logger.info("Iteration finished.")
Example #4
    def train(self, loader, st_step=1, val=None):
        val = val or {}
        self.gen.train()
        self.disc.train()

        # loss stats
        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "ac", "ac_gen")
        # discriminator stats
        discs = utils.AverageMeters("real", "fake", "real_font", "real_char",
                                    "fake_font", "fake_char", "real_acc",
                                    "fake_acc", "real_font_acc",
                                    "real_char_acc", "fake_font_acc",
                                    "fake_char_acc")
        # etc stats
        stats = utils.AverageMeters("B_style", "B_target", "ac_acc",
                                    "ac_gen_acc")

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")
        for (style_ids, style_char_ids, style_comp_ids, style_imgs, trg_ids,
             trg_char_ids, trg_comp_ids, trg_imgs,
             *content_imgs) in cyclize(loader):
            B = trg_imgs.size(0)
            stats.updates({"B_style": style_imgs.size(0), "B_target": B})

            style_ids = style_ids.cuda()
            #  style_char_ids = style_char_ids.cuda()
            style_comp_ids = style_comp_ids.cuda()
            style_imgs = style_imgs.cuda()
            trg_ids = trg_ids.cuda()
            trg_char_ids = trg_char_ids.cuda()
            trg_comp_ids = trg_comp_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            # infer
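            # encode_write presumably stores per-component style features in
            # the generator's memory; read_decode then composes the target
            # glyph from them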
            comp_feats = self.gen.encode_write(style_ids, style_comp_ids,
                                               style_imgs)
            out = self.gen.read_decode(trg_ids, trg_comp_ids)

            # D loss
            real, real_font, real_char, real_feats = self.disc(trg_imgs,
                                                               trg_ids,
                                                               trg_char_ids,
                                                               out_feats=True)
            fake, fake_font, fake_char = self.disc(out.detach(), trg_ids,
                                                   trg_char_ids)
            self.add_gan_d_loss(real, real_font, real_char, fake, fake_font,
                                fake_char)

            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            # G loss
            fake, fake_font, fake_char, fake_feats = self.disc(out,
                                                               trg_ids,
                                                               trg_char_ids,
                                                               out_feats=True)
            self.add_gan_g_loss(real, real_font, real_char, fake, fake_font,
                                fake_char)

            # feature matching loss
            self.add_fm_loss(real_feats, fake_feats)

            # disc stats
            def racc(x):
                return (x > 0.).float().mean().item()

            def facc(x):
                return (x < 0.).float().mean().item()

            discs.updates(
                {
                    "real": real.mean().item(),
                    "fake": fake.mean().item(),
                    "real_font": real_font.mean().item(),
                    "real_char": real_char.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_char": fake_char.mean().item(),
                    'real_acc': racc(real),
                    'fake_acc': facc(fake),
                    'real_font_acc': racc(real_font),
                    'real_char_acc': racc(real_char),
                    'fake_font_acc': facc(fake_font),
                    'fake_char_acc': facc(fake_char)
                }, B)

            # pixel loss
            self.add_pixel_loss(out, trg_imgs)

            self.g_optim.zero_grad()
            # NOTE: the ac loss generates and leaves gradients on G, so
            # g_optim.zero_grad() must come before the ac loss and
            # g_backward() must follow it.
            if self.aux_clf is not None:
                self.add_ac_losses_and_update_stats(comp_feats, style_comp_ids,
                                                    out, trg_comp_ids, stats)

                self.ac_optim.zero_grad()
                self.ac_backward(retain_graph=True)
                self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)

            # generator EMA
            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(style_ids, style_comp_ids, style_imgs, trg_ids,
                                trg_comp_ids)

            # after step
            if self.step % self.cfg['tb_freq'] == 0:
                self.plot(losses, discs, stats)

            if self.step % self.cfg['print_freq'] == 0:
                self.log(losses, discs, stats)
                losses.resets()
                discs.resets()
                stats.resets()

            if self.step % self.cfg['val_freq'] == 0:
                epoch = self.step / len(loader)
                self.logger.info("Validation at Epoch = {:.3f}".format(epoch))
                self.evaluator.merge_and_log_image('d1', out, trg_imgs,
                                                   self.step)
                self.evaluator.validation(self.gen, self.step)

                # if non-BN generator, sync max singular value of spectral norm.
                if not self.is_bn_gen:
                    self.sync_g_ema(style_ids, style_comp_ids, style_imgs,
                                    trg_ids, trg_comp_ids)
                self.evaluator.validation(self.gen_ema,
                                          self.step,
                                          extra_tag='_EMA')

                # save freq == val freq
                self.save(loss_dic['g_total'], self.cfg['save'],
                          self.cfg.get('save_freq', self.cfg['val_freq']))

            if self.step >= self.cfg['max_iter']:
                self.logger.info("Iteration finished.")
                break

            self.step += 1
Example #5
    def cross_validation(self,
                         gen,
                         step,
                         loader,
                         tag,
                         n_batches,
                         n_log=64,
                         save_dir=None):
        """Validation using splitted cross-validation set
        Args:
            n_log: # of images to log
            save_dir: if given, images are saved to save_dir
        """
        if save_dir:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        outs = []
        trgs = []
        n_accum = 0

        losses = utils.AverageMeters("l1", "ssim", "msssim")
        for i, (style_ids, style_comp_ids, style_imgs, trg_ids, trg_comp_ids,
                content_imgs, trg_imgs) in enumerate(loader):
            if i == n_batches:
                break

            style_ids = style_ids.cuda()
            style_comp_ids = style_comp_ids.cuda()
            style_imgs = style_imgs.cuda()
            trg_ids = trg_ids.cuda()
            trg_comp_ids = trg_comp_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            gen.encode_write(style_ids, style_comp_ids, style_imgs)
            out = gen.read_decode(trg_ids, trg_comp_ids)
            B = len(out)

            # log images
            if n_accum < n_log:
                trgs.append(trg_imgs)
                outs.append(out)
                n_accum += B

                if n_accum >= n_log:
                    # log results
                    outs = torch.cat(outs)[:n_log]
                    trgs = torch.cat(trgs)[:n_log]
                    self.merge_and_log_image(tag, outs, trgs, step)

            l1, ssim, msssim = self.get_pixel_losses(out, trg_imgs,
                                                     self.unify_resize_method)
            losses.updates(
                {
                    "l1": l1.item(),
                    "ssim": ssim.item(),
                    "msssim": msssim.item()
                }, B)

            # save images
            if save_dir:
                font_ids = trg_ids.detach().cpu().numpy()
                images = out.detach().cpu()  # [B, 1, 128, 128]
                char_comp_ids = trg_comp_ids.detach().cpu().numpy()  # [B, n_comp_types]
                for font_id, image, comp_ids in zip(font_ids, images,
                                                    char_comp_ids):
                    font_name = loader.dataset.fonts[font_id]  # name.ttf
                    font_name = Path(font_name).stem  # remove ext
                    (save_dir / font_name).mkdir(parents=True, exist_ok=True)
                    if self.language == 'kor':
                        char = kor.compose(*comp_ids)
                    elif self.language == 'thai':
                        char = thai.compose_ids(*comp_ids)
                    else:
                        raise ValueError('unsupported language: {}'.format(
                            self.language))

                    uni = "".join([f'{ord(each):04X}' for each in char])
                    path = save_dir / font_name / "{}_{}.png".format(
                        font_name, uni)
                    utils.save_tensor_to_image(image, path)

        self.logger.info(
            "  [Valid] {tag:30s} | Step {step:7d}  L1 {L.l1.avg:7.4f}  SSIM {L.ssim.avg:7.4f}"
            "  MSSSIM {L.msssim.avg:7.4f}".format(tag=tag, step=step,
                                                  L=losses))

        return losses.l1.avg, losses.ssim.avg, losses.msssim.avg
Example #6
def train(train_queue, valid_queue, model, criterion, LatencyLoss, optimizer,
          alpha_optimizer, lr, epoch, writer, update_schedule):

    # use the builtin sum: np.sum over a generator is deprecated
    arch_param_num = sum(
        np.prod(params.size()) for params in model.module.arch_parameters())
    binary_gates_num = len(list(model.module.binary_gates()))
    weight_param_num = len(list(model.module.weight_parameters()))
    print('#arch_params: %d\t#binary_gates: %d\t#weight_params: %d' %
          (arch_param_num, binary_gates_num, weight_param_num))

    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    entropy = utils.AverageMeters('Entropy', ':.4e')

    progress = utils.ProgressMeter(len(train_queue),
                                   batch_time,
                                   data_time,
                                   losses,
                                   top1,
                                   top5,
                                   prefix="Epoch: [{}]".format(epoch))
    cur_step = epoch * len(train_queue)
    writer.add_scalar('train/lr', lr, cur_step)

    model.train()
    end = time.time()
    for step, (input, target) in enumerate(train_queue):

        # measure data loading time
        data_time.update(time.time() - end)

        net_entropy = model.module.entropy()
        entropy.update(net_entropy.data.item() / arch_param_num, 1)

        # sample random path
        model.module.reset_binary_gates()
        # close unused module
        model.module.unused_modules_off()

        n = input.size(0)
        # Variable(requires_grad=False) is deprecated; plain tensors suffice
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        logits = model(input)
        if config.label_smooth > 0.0:
            loss = utils.cross_entropy_with_label_smoothing(
                logits, target, config.label_smooth)
        else:
            loss = criterion(logits, target)

        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))

        losses.update(loss.item(), n)
        top1.update(acc1.item(), n)
        top5.update(acc5.item(), n)
        model.zero_grad()

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        # unused module back
        model.module.unused_modules_back()

        # Train the weights first; after a few epochs, start updating the
        # architecture parameters as well
        if epoch > 0:
            #### official warm-up lr (kept for reference) ####
            # T_cur = epoch * len(train_queue) + step
            # lr_max = 0.05
            # T_total = config.warmup_epochs * len(train_queue)
            # lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
            #### official warm-up lr ####
            for j in range(update_schedule.get(step, 0)):
                model.train()
                latency_loss = 0
                expected_loss = 0

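                # NOTE: re-creating the iterator restarts valid_queue on every
                # architecture update; with shuffling enabled this still draws
                # a fresh batch, but it re-spawns DataLoader workers each time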
                valid_iter = iter(valid_queue)
                input_valid, target_valid = next(valid_iter)
                input_valid = input_valid.cuda(non_blocking=True)
                target_valid = target_valid.cuda(non_blocking=True)
                model.module.reset_binary_gates()
                model.module.unused_modules_off()
                output_valid = model(input_valid).float()
                loss_ce = criterion(output_valid, target_valid)
                expected_loss = LatencyLoss.expected_latency(model)
                expected_loss_tensor = torch.tensor([expected_loss],
                                                    device='cuda')
                latency_loss = LatencyLoss(loss_ce, expected_loss_tensor,
                                           config)
                # compute gradient and do SGD step
                # zero grads of weight_param, arch_param & binary_param
                model.zero_grad()
                latency_loss.backward()
                # set architecture parameter gradients
                model.module.set_arch_param_grad()
                alpha_optimizer.step()
                model.module.rescale_updated_arch_param()
                model.module.unused_modules_back()
                log_str = 'Architecture [%d-%d]\t Loss %.4f\t %s LatencyLoss: %s' % (
                    epoch, step, latency_loss, config.target_hardware,
                    expected_loss)
                utils.write_log(arch_logger_path, log_str)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % config.print_freq == 0 or step == len(train_queue) - 1:
            logger.info('train step:%03d %03d  loss:%e top1:%05f top5:%05f',
                        step, len(train_queue), losses.avg, top1.avg, top5.avg)
            progress.print(step)
    writer.add_scalar('train/loss', losses.avg, cur_step)
    writer.add_scalar('train/top1', top1.avg, cur_step)
    writer.add_scalar('train/top5', top5.avg, cur_step)

    return top1.avg, losses.avg
Example #7
def warm_up(train_queue, valid_queue, model, criterion, Latency, optimizer,
            epoch, writer):
    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    progress = utils.ProgressMeter(len(train_queue),
                                   batch_time,
                                   data_time,
                                   losses,
                                   top1,
                                   top5,
                                   prefix="Epoch: [{}]".format(epoch))
    cur_step = epoch * len(train_queue)
    model.train()
    print('\n', '-' * 30, 'Warmup epoch: %d' % (epoch), '-' * 30, '\n')
    end = time.time()
    lr = 0
    for step, (input, target) in enumerate(train_queue):
        # measure data loading time
        data_time.update(time.time() - end)
        # official warm-up lr: cosine annealing from lr_max to 0 over warm-up
        T_cur = epoch * len(train_queue) + step
        lr_max = 0.05
        T_total = config.warmup_epochs * len(train_queue)
        lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        writer.add_scalar('warm-up/lr', lr, cur_step + step)

        #### official warm-up lr ####

        n = input.size(0)
        # Variable(requires_grad=False) is deprecated; plain tensors suffice
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        model.module.reset_binary_gates()
        model.module.unused_modules_off()

        logits = model(input)
        if config.label_smooth > 0 and epoch > config.warmup_epochs:
            loss = utils.cross_entropy_with_label_smoothing(
                logits, target, config.label_smooth)
        else:
            loss = criterion(logits, target)
        model.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
        losses.update(loss.item(), n)
        top1.update(acc1.item(), n)
        top5.update(acc5.item(), n)

        # unused modules back
        model.module.unused_modules_back()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % config.print_freq == 0 or step == len(train_queue) - 1:
            logger.info(
                'warmup train step:%03d %03d  loss:%e top1:%05f top5:%05f',
                step, len(train_queue), losses.avg, top1.avg, top5.avg)
            progress.print(step)
    # log the epoch-level averages once per warm-up epoch
    writer.add_scalar('warmup-train/loss', losses.avg, cur_step)
    writer.add_scalar('warmup-train/top1', top1.avg, cur_step)
    writer.add_scalar('warmup-train/top5', top5.avg, cur_step)

    logger.info('warmup epoch %d lr %e', epoch, lr)
    # set chosen op active
    model.module.set_chosen_op_active()
    # remove unused modules
    model.module.unused_modules_off()
    valid_top1, valid_top5, valid_loss = validate_warmup(
        valid_queue, model, epoch, criterion, writer)
    shape = [1, 3, 224, 224]
    input_var = torch.zeros(shape, device=device)
    flops = model.module.get_flops(input_var)
    if config.target_hardware in [None, 'flops']:
        latency = 0
    else:
        latency = Latency.predict_latency(model)
    logger.info(
        'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc '
        '{4:.3f}\tflops: {5:.1f}M {6:.3f}ms'.format(epoch,
                                                    config.warmup_epochs,
                                                    valid_loss, valid_top1,
                                                    valid_top5, flops / 1e6,
                                                    latency))
    model.module.unused_modules_back()

    config.warmup = epoch + 1 < config.warmup_epochs
    state_dict = model.state_dict()
    # remove architecture params and binary gates
    for key in list(state_dict.keys()):
        if 'alpha' in key or 'path' in key:
            state_dict.pop(key)
    checkpoint = {
        'state_dict': state_dict,
        'warmup': config.warmup,
    }
    if config.warmup:
        checkpoint['warmup_epoch'] = epoch

    checkpoint['epoch'] = epoch
    checkpoint['w_optimizer'] = optimizer.state_dict()

    save_model(model, checkpoint, model_name='warmup.pth.tar')
    return top1.avg, losses.avg
Example #8
    def train(self, loader, st_step=1, max_step=100000):

        self.gen.train()
        self.disc.train()

        losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                     "ac", "ac_gen", "dec_const")
        discs = utils.AverageMeters("real_font", "real_uni", "fake_font",
                                    "fake_uni", "real_font_acc",
                                    "real_uni_acc", "fake_font_acc",
                                    "fake_uni_acc")
        # etc stats
        stats = utils.AverageMeters("B_style", "B_target", "ac_acc",
                                    "ac_gen_acc")

        self.step = st_step
        self.clear_losses()

        self.logger.info("Start training ...")

        for (in_style_ids, in_comp_ids, in_imgs, trg_style_ids, trg_uni_ids,
             trg_comp_ids, trg_imgs, content_imgs) in cyclize(loader):

            epoch = self.step // len(loader)
            if self.cfg.use_ddp and (self.step % len(loader)) == 0:
                loader.sampler.set_epoch(epoch)

            B = trg_imgs.size(0)
            stats.updates({"B_style": in_imgs.size(0), "B_target": B})

            in_style_ids = in_style_ids.cuda()
            in_comp_ids = in_comp_ids.cuda()
            in_imgs = in_imgs.cuda()

            trg_style_ids = trg_style_ids.cuda()
            trg_imgs = trg_imgs.cuda()

            content_imgs = content_imgs.cuda()

            if self.cfg.use_half:
                in_imgs = in_imgs.half()
                content_imgs = content_imgs.half()

            feat_styles, feat_comps = self.gen.encode_write_fact(
                in_style_ids, in_comp_ids, in_imgs, write_comb=True)
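            # recombine the factorized features: elementwise product of the
            # style and component factors, summed over the factor dimension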
            feats_rc = (feat_styles * feat_comps).sum(1)
            ac_feats = feats_rc
            self.add_dec_const_loss()

            out = self.gen.read_decode(trg_style_ids,
                                       trg_comp_ids,
                                       content_imgs=content_imgs,
                                       phase="fact",
                                       try_comb=True)

            trg_uni_disc_ids = trg_uni_ids.cuda()

            real_font, real_uni, *real_feats = self.disc(
                trg_imgs,
                trg_style_ids,
                trg_uni_disc_ids,
                out_feats=self.cfg['fm_layers'])

            fake_font, fake_uni = self.disc(out.detach(), trg_style_ids,
                                            trg_uni_disc_ids)
            self.add_gan_d_loss(real_font, real_uni, fake_font, fake_uni)

            self.d_optim.zero_grad()
            self.d_backward()
            self.d_optim.step()

            fake_font, fake_uni, *fake_feats = self.disc(
                out,
                trg_style_ids,
                trg_uni_disc_ids,
                out_feats=self.cfg['fm_layers'])
            self.add_gan_g_loss(real_font, real_uni, fake_font, fake_uni)

            self.add_fm_loss(real_feats, fake_feats)

            def racc(x):
                return (x > 0.).float().mean().item()

            def facc(x):
                return (x < 0.).float().mean().item()

            discs.updates(
                {
                    "real_font": real_font.mean().item(),
                    "real_uni": real_uni.mean().item(),
                    "fake_font": fake_font.mean().item(),
                    "fake_uni": fake_uni.mean().item(),
                    'real_font_acc': racc(real_font),
                    'real_uni_acc': racc(real_uni),
                    'fake_font_acc': facc(fake_font),
                    'fake_uni_acc': facc(fake_uni)
                }, B)

            self.add_pixel_loss(out, trg_imgs)

            self.g_optim.zero_grad()
            if self.aux_clf is not None:
                self.add_ac_losses_and_update_stats(ac_feats, in_comp_ids, out,
                                                    trg_comp_ids, stats)
                self.ac_optim.zero_grad()
                self.ac_backward()
                self.ac_optim.step()

            self.g_backward()
            self.g_optim.step()

            loss_dic = self.clear_losses()
            losses.updates(loss_dic, B)  # accum loss stats

            self.accum_g()
            if self.is_bn_gen:
                self.sync_g_ema(in_style_ids,
                                in_comp_ids,
                                in_imgs,
                                trg_style_ids,
                                trg_comp_ids,
                                content_imgs=content_imgs)

            torch.cuda.synchronize()

            if self.cfg.gpu <= 0:
                if self.step % self.cfg['tb_freq'] == 0:
                    self.baseplot(losses, discs, stats)
                    self.plot(losses)

                if self.step % self.cfg['print_freq'] == 0:
                    self.log(losses, discs, stats)
                    self.logger.debug(
                        "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                        torch.cuda.max_memory_allocated() / 1000 / 1000,
                        torch.cuda.max_memory_reserved() / 1000 / 1000)
                    losses.resets()
                    discs.resets()
                    stats.resets()

                if self.step % self.cfg['val_freq'] == 0:
                    epoch = self.step / len(loader)
                    self.logger.info(
                        "Validation at Epoch = {:.3f}".format(epoch))
                    if not self.is_bn_gen:
                        self.sync_g_ema(in_style_ids,
                                        in_comp_ids,
                                        in_imgs,
                                        trg_style_ids,
                                        trg_comp_ids,
                                        content_imgs=content_imgs)
                    self.evaluator.cp_validation(self.gen_ema,
                                                 self.cv_loaders,
                                                 self.step,
                                                 phase="fact",
                                                 ext_tag="factorize")

                    self.save(loss_dic['g_total'], self.cfg['save'],
                              self.cfg.get('save_freq', self.cfg['val_freq']))

            if self.step >= max_step:
                break

            self.step += 1

        self.logger.info("Iteration finished.")
Example #9
def validate(val_loader, model, epoch, criterion, config, early_stopping,
             writer, start):
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    if 'DALIClassificationIterator' in val_loader.__class__.__name__:
        progress = utils.ProgressMeter(math.ceil(val_loader._size /
                                                 config.batch_size),
                                       batch_time,
                                       losses,
                                       top1,
                                       top5,
                                       prefix='Test: ')
    else:
        progress = utils.ProgressMeter(len(val_loader),
                                       batch_time,
                                       losses,
                                       top1,
                                       top5,
                                       prefix='Test: ')
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        if 'DALIClassificationIterator' in val_loader.__class__.__name__:
            for i, data in enumerate(val_loader):
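                # DALI pipelines yield a list of output dicts:
                # data[0]['data'] is the image batch, data[0]['label'] the labels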
                images = data[0]['data']
                target = data[0]['label'].squeeze().long().cuda(
                    non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                if config.distributed:
                    reduced_loss = reduce_tensor(loss.data,
                                                 world_size=config.world_size)
                    acc1 = reduce_tensor(acc1, world_size=config.world_size)
                    acc5 = reduce_tensor(acc5, world_size=config.world_size)
                else:
                    reduced_loss = loss.data
                losses.update(to_python_float(reduced_loss), images.size(0))
                top1.update(to_python_float(acc1), images.size(0))
                top5.update(to_python_float(acc5), images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % config.print_freq == 0:
                    progress.print(i)
        else:
            for i, (images, target) in enumerate(val_loader):
                images = images.cuda(device, non_blocking=True)
                target = target.cuda(device, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                top5.update(acc5.item(), images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % config.print_freq == 0:
                    progress.print(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

        early_stopping(losses.avg, model, ckpt_dir=config.path)
        if early_stopping.early_stop:
            print("Early stopping")
            utils.time(time.time() - start)
            os._exit(0)
        writer.add_scalar('val/loss', losses.avg, epoch)
        writer.add_scalar('val/top1', top1.avg, epoch)
        writer.add_scalar('val/top5', top5.avg, epoch)
    return top1.avg
Example #10
def train(train_loader, model, criterion, optimizer, epoch, config, writer):
    utils.adjust_learning_rate(optimizer, epoch, config)
    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    if 'DALIClassificationIterator' in train_loader.__class__.__name__:
        # TODO: for distributed training this may need to be multiplied by
        # config.world_size
        progress = utils.ProgressMeter(math.ceil(train_loader._size /
                                                 config.batch_size),
                                       batch_time,
                                       data_time,
                                       losses,
                                       top1,
                                       top5,
                                       prefix="Epoch: [{}]".format(epoch))
        cur_step = epoch * math.ceil(train_loader._size / config.batch_size)
    else:
        progress = utils.ProgressMeter(len(train_loader),
                                       batch_time,
                                       data_time,
                                       losses,
                                       top1,
                                       top5,
                                       prefix="Epoch: [{}]".format(epoch))

        cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', config.lr, cur_step)

    model.train()

    end = time.time()
    if 'DALIClassificationIterator' in train_loader.__class__.__name__:
        for i, data in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            images = data[0]['data']
            target = data[0]['label'].squeeze().long().cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if config.distributed:
                reduced_loss = reduce_tensor(loss.data,
                                             world_size=config.world_size)
                acc1 = reduce_tensor(acc1, world_size=config.world_size)
                acc5 = reduce_tensor(acc5, world_size=config.world_size)
            else:
                reduced_loss = loss.data
            losses.update(to_python_float(reduced_loss), images.size(0))
            top1.update(to_python_float(acc1), images.size(0))
            top5.update(to_python_float(acc5), images.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
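            # if an fp16 optimizer wrapper is used (e.g. Apex's FP16_Optimizer,
            # assumed here), the scaled loss is backpropagated through
            # optimizer.backward(loss) instead of loss.backward()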
            if config.fp16_allreduce:
                optimizer.backward(loss)
            else:
                loss.backward()
            optimizer.step()
            torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % config.print_freq == 0:
                progress.print(i)
            writer.add_scalar('train/loss', loss.item(), cur_step + i)
            writer.add_scalar('train/top1', top1.avg, cur_step + i)
            writer.add_scalar('train/top5', top5.avg, cur_step + i)
    else:
        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            images = images.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % config.print_freq == 0:
                progress.print(i)

            writer.add_scalar('train/loss', loss.item(), cur_step + i)
            writer.add_scalar('train/top1', top1.avg, cur_step + i)
            writer.add_scalar('train/top5', top5.avg, cur_step + i)