def augment(out_dir, chkpt_path, train_loader, valid_loader, model, writer,
            logger, device, config):

    w_optim = utils.get_optim(model.weights(), config.w_optim)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_optim.lr_min)

    init_epoch = -1

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    logger.info("Model params count: {:.3f} M, size: {:.3f} MB".format(
        utils.param_size(model), utils.param_count(model)))

    # training loop
    logger.info('begin training')
    best_top1 = 0.
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch >= tot_epochs: break

        drop_prob = config.drop_path_prob * epoch / tot_epochs
        model.drop_path_prob(drop_prob)

        lr = lr_scheduler.get_lr()[0]

        # training
        train(train_loader, None, model, writer, logger, None, w_optim, None,
              lr, epoch, tot_epochs, device, config)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch, tot_epochs,
                        cur_step, device, config)

        # save
        if best_top1 < top1:
            best_top1 = top1
            is_best = True
        else:
            is_best = False

        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, None, lr_scheduler, epoch,
                            logger)

        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
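
# --- Hedged sketch (not part of the original module) ------------------------
# augment() pairs a cosine-annealed weight LR with a linearly increasing
# drop-path probability. This standalone snippet reproduces both schedules in
# isolation; the concrete numbers (lr=0.025, 600 epochs, max drop prob 0.2)
# are assumptions for illustration, not values from this repo's config.
import torch

_w = torch.nn.Parameter(torch.zeros(1))
_optim = torch.optim.SGD([_w], lr=0.025)
_sched = torch.optim.lr_scheduler.CosineAnnealingLR(_optim, T_max=600, eta_min=0.0)

_tot_epochs, _drop_path_prob = 600, 0.2
for _epoch in range(_tot_epochs):
    _lr = _sched.get_last_lr()[0]                    # LR used for this epoch
    _drop = _drop_path_prob * _epoch / _tot_epochs   # linear ramp, as in augment()
    # ... train one epoch with _lr and model.drop_path_prob(_drop) ...
    _sched.step()
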
def get_model(config, device, dev_list, genotype=None):
    mtype = config.type
    configure_ops(config)
    if mtype in model_creator:
        config.augment = genotype is not None
        net, arch = model_creator[mtype](config)
        crit = get_net_crit(config).to(device)
        prim = gt.get_primitives()
        model = NASController(config, net, crit, prim, dev_list).to(device)
        if config.augment:
            print("genotype = {}".format(genotype))
            model.build_from_genotype(genotype)
            model.to(device=device)
        if config.verbose: print(model)
        mb_params = param_size(model)
        n_params = param_count(model)
        print("Model params count: {:.3f} M, size: {:.3f} MB".format(n_params, mb_params))
        NASModule.set_device(dev_list)
        return model, arch
    else:
        raise Exception("invalid model type")
def save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler, epoch,
                    logger):
    try:
        save_path = os.path.join(out_dir, 'chkpt_%03d.pt' % (epoch + 1))
        torch.save(
            {
                'model': model.state_dict(),
                'arch': NASModule.nasmod_state_dict(),
                'w_optim': w_optim.state_dict(),
                'a_optim': None if a_optim is None else a_optim.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
            }, save_path)
        logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        logger.error("Save checkpoint failed: " + str(e))
    def __init__(self, config, n_nodes, chn_in, stride, shared_a, allocator,
                 merger_state, merger_out, enumerator, preproc, aggregate,
                 edge_cls, edge_kwargs):
        super().__init__()
        global edge_id
        self.edge_id = edge_id
        edge_id += 1
        self.n_nodes = n_nodes
        self.stride = stride
        self.chn_in = chn_in
        self.n_input = len(chn_in)
        self.n_states = self.n_input + self.n_nodes
        self.n_input_e = len(edge_kwargs['chn_in'])
        self.shared_a = shared_a
        if shared_a:
            NASModule.add_shared_param()
        self.allocator = allocator(self.n_input, self.n_nodes)
        self.merger_state = merger_state()
        self.merger_out = merger_out(start=self.n_input)
        self.merge_out_range = self.merger_out.merge_range(self.n_states)
        self.enumerator = enumerator()

        chn_states = []
        if not preproc:
            self.preprocs = None
            chn_states.extend(chn_in)
        else:
            chn_cur = edge_kwargs['chn_in'][0]
            self.preprocs = nn.ModuleList()
            for i in range(self.n_input):
                self.preprocs.append(preproc[i](chn_in[i], chn_cur, False))
                chn_states.append(chn_cur)

        if not config.augment:
            self.fixed = False
            self.dag = nn.ModuleList()
            self.edges = []
            self.num_edges = 0
            for i in range(n_nodes):
                cur_state = self.n_input + i
                self.dag.append(nn.ModuleList())
                num_edges = self.enumerator.len_enum(cur_state, self.n_input_e)
                for sidx in self.enumerator.enum(cur_state, self.n_input_e):
                    e_chn_in = self.allocator.chn_in(
                        [chn_states[s] for s in sidx], sidx, cur_state)
                    edge_kwargs['chn_in'] = e_chn_in
                    edge_kwargs['stride'] = stride if all(s < self.n_input
                                                          for s in sidx) else 1
                    edge_kwargs['shared_a'] = shared_a
                    e = edge_cls(**edge_kwargs)
                    self.dag[i].append(e)
                    self.edges.append(e)
                self.num_edges += num_edges
                chn_states.append(
                    self.merger_state.chn_out(
                        [ei.chn_out for ei in self.dag[i]]))
                self.chn_out = self.merger_out.chn_out(chn_states)
            # print('DAGLayer: etype:{} chn_in:{} chn:{} #n:{} #e:{}'.format(str(edge_cls), self.chn_in, edge_kwargs['chn_in'][0],self.n_nodes, self.num_edges))
            # print('DAGLayer param count: {:.6f}'.format(param_count(self)))
        else:
            self.chn_states = chn_states
            self.edge_cls = edge_cls
            self.edge_kwargs = edge_kwargs
            self.fixed = True

        if aggregate is not None:
            self.merge_filter = aggregate(n_in=self.n_input + self.n_nodes,
                                          n_out=self.n_input +
                                          self.n_nodes // 2)
        else:
            self.merge_filter = None
        self.chn_out = self.merger_out.chn_out(chn_states)
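
# --- Hedged sketch (not part of the original module) ------------------------
# The DAG construction above only relies on the enumerator exposing enum() and
# len_enum(). This variant yields every combination of n_in distinct
# predecessor states for the current node, which is one plausible policy; the
# repo's actual enumerators may differ.
import itertools

class CombinationEnumerator:
    def enum(self, n_states, n_in):
        # all ways to pick n_in distinct predecessor states, as index tuples
        return itertools.combinations(range(n_states), n_in)

    def len_enum(self, n_states, n_in):
        return sum(1 for _ in self.enum(n_states, n_in))

# DARTS-like cell (2 inputs, single-input edges): node i sees 2 + i states,
# giving 2, 3, 4, 5 candidate edges for nodes 0..3 (14 in total).
assert [CombinationEnumerator().len_enum(2 + i, 1) for i in range(4)] == [2, 3, 4, 5]
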
    def __init__(self,
                 config,
                 n_nodes,
                 chn_in,
                 stride,
                 shared_a,
                 allocator,
                 merger_out,
                 preproc,
                 aggregate,
                 child_cls,
                 child_kwargs,
                 edge_cls,
                 edge_kwargs,
                 children=None,
                 edges=None):
        super().__init__()
        self.edges = nn.ModuleList()
        self.subnets = nn.ModuleList()
        chn_in = (chn_in, ) if isinstance(chn_in, int) else chn_in
        self.n_input = len(chn_in)
        self.n_nodes = n_nodes
        self.n_states = self.n_input + self.n_nodes
        self.allocator = allocator(self.n_input, self.n_nodes)
        self.merger_out = merger_out(start=self.n_input)
        self.merge_out_range = self.merger_out.merge_range(self.n_states)
        if shared_a:
            NASModule.add_shared_param()

        chn_states = []
        if not preproc:
            self.preprocs = None
            chn_states.extend(chn_in)
        else:
            chn_cur = edge_kwargs['chn_in'][0]
            self.preprocs = nn.ModuleList()
            for i in range(self.n_input):
                self.preprocs.append(
                    preproc(chn_in[i], chn_cur, 1, 1, 0, False))
                chn_states.append(chn_cur)

        sidx = range(self.n_input)
        for i in range(self.n_nodes):
            e_chn_in = self.allocator.chn_in([chn_states[s] for s in sidx],
                                             sidx, i)
            if edges is not None:
                self.edges.append(edges[i])
                c_chn_in = edges[i].chn_out
            elif edge_cls is not None:
                edge_kwargs['chn_in'] = e_chn_in
                edge_kwargs['stride'] = stride
                if 'shared_a' in edge_kwargs:
                    edge_kwargs['shared_a'] = shared_a
                e = edge_cls(**edge_kwargs)
                self.edges.append(e)
                c_chn_in = e.chn_out
            else:
                self.edges.append(None)
                c_chn_in = e_chn_in
            if children is not None:
                self.subnets.append(children[i])
            elif child_cls is not None:
                child_kwargs['chn_in'] = c_chn_in
                self.subnets.append(child_cls(**child_kwargs))
            else:
                self.subnets.append(None)

        if aggregate is not None:
            self.merge_filter = aggregate(n_in=self.n_states,
                                          n_out=self.n_states // 2)
        else:
            self.merge_filter = None
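
# --- Hedged sketch (not part of the original module) ------------------------
# Both layer constructors above call allocator.chn_in() to turn the channel
# counts of the selected predecessor states into the chn_in an edge receives.
# A pass-through allocator is the simplest plausible policy; the repo's real
# allocators may replicate or reshape channels.
class PassThroughAllocator:
    def __init__(self, n_inputs, n_nodes):
        self.n_inputs = n_inputs
        self.n_nodes = n_nodes

    def chn_in(self, chn_states, sidx, cur_state):
        # one channel count per selected predecessor state
        return tuple(chn_states)

assert PassThroughAllocator(2, 4).chn_in([16, 16], sidx=(0, 1), cur_state=2) == (16, 16)
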
    def step(self, trn_X, trn_y, val_X, val_y, xi, w_optim, a_optim):
        """ Compute unrolled loss and backward its gradients
        Args:
            xi: learning rate for virtual gradient step (same as net lr)
            w_optim: weights optimizer - for virtual step
        """
        a_optim.zero_grad()

        # sample k
        if self.sample:
            NASModule.param_module_call('sample_ops', n_samples=self.n_samples)

        loss = self.net.loss(val_X, val_y)

        m_out_dev = []
        for dev_id in NASModule.get_device():
            m_out = [
                m.get_state('m_out' + dev_id) for m in NASModule.modules()
            ]
            m_len = len(m_out)
            m_out_dev.extend(m_out)
        m_grad = torch.autograd.grad(loss, m_out_dev)
        for i, dev_id in enumerate(NASModule.get_device()):
            NASModule.param_backward_from_grad(
                m_grad[i * m_len:(i + 1) * m_len], dev_id)

        if not self.renorm:
            a_optim.step()
        else:
            # renormalization
            prev_pw = []
            for p, m in NASModule.param_modules():
                s_op = m.get_state('s_op')
                pdt = p.detach()
                pp = pdt.index_select(-1, s_op)
                if pp.size() == pdt.size():
                    # all ops were sampled: nothing to renormalize for this module
                    prev_pw.append(None)
                    continue
                k = torch.sum(torch.exp(pdt)) / torch.sum(torch.exp(pp)) - 1
                prev_pw.append(k)

            a_optim.step()

            for kprev, (p, m) in zip(prev_pw, NASModule.param_modules()):
                if kprev is None:
                    continue
                s_op = m.get_state('s_op')
                pdt = p.detach()
                pp = pdt.index_select(-1, s_op)
                k = torch.sum(torch.exp(pdt)) / torch.sum(torch.exp(pp)) - 1
                with torch.no_grad():
                    for i in s_op:
                        p[i] += torch.log(k) - torch.log(kprev)

        NASModule.module_call('reset_ops')
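
# --- Hedged sketch (not part of the original module) ------------------------
# What the renormalization in step() preserves: with
#   k = sum(exp(p)) / sum(exp(p[s_op])) - 1
# (the ratio of unselected to selected softmax mass), adding
# log(k_after) - log(k_before) to the selected logits after a_optim.step()
# restores that ratio to its pre-step value, so updating only the sampled ops
# does not shift probability mass to or from the ops that were not sampled.
import torch

_p = torch.tensor([0.1, 0.2, 0.3, 0.4])     # toy alpha logits
_s_op = torch.tensor([1, 3])                # indices of the sampled ops

def _mass_ratio(p):
    pp = p.index_select(-1, _s_op)
    return torch.sum(torch.exp(p)) / torch.sum(torch.exp(pp)) - 1

_k_before = _mass_ratio(_p)
_p[_s_op] += 0.5                            # stand-in for the optimizer update
_k_after = _mass_ratio(_p)
_p[_s_op] += torch.log(_k_after) - torch.log(_k_before)
assert torch.allclose(_mass_ratio(_p), _k_before)
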
def search(out_dir, chkpt_path, w_train_loader, a_train_loader, model, arch,
           writer, logger, device, config):
    valid_loader = a_train_loader

    w_optim = utils.get_optim(model.weights(), config.w_optim)
    a_optim = utils.get_optim(model.alphas(), config.a_optim)

    init_epoch = -1
    lr_sched_state = None

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        NASModule.nasmod_load_state_dict(checkpoint['arch'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        a_optim.load_state_dict(checkpoint['a_optim'])
        # the LR scheduler is created below; defer loading its state until then
        lr_sched_state = checkpoint['lr_scheduler']
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    architect = arch(config, model)

    # warmup training loop
    logger.info('begin warmup training')
    try:
        if config.warmup_epochs > 0:
            warmup_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                w_optim, config.warmup_epochs, eta_min=config.w_optim.lr_min)
            last_epoch = 0
        else:
            last_epoch = -1

        tot_epochs = config.warmup_epochs
        for epoch in itertools.count(init_epoch + 1):
            if epoch >= tot_epochs: break
            lr = warmup_lr_scheduler.get_lr()[0]
            # training
            train(w_train_loader, None, model, writer, logger, architect,
                  w_optim, a_optim, lr, epoch, tot_epochs, device, config)
            # validation
            cur_step = (epoch + 1) * len(w_train_loader)
            top1 = validate(valid_loader, model, writer, logger, epoch,
                            tot_epochs, cur_step, device, config)
            warmup_lr_scheduler.step()
            print("")
    except KeyboardInterrupt:
        print('warmup training skipped')

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim,
        config.epochs,
        eta_min=config.w_optim.lr_min,
        last_epoch=last_epoch)
    if lr_sched_state is not None:
        lr_scheduler.load_state_dict(lr_sched_state)

    save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler, init_epoch,
                    logger)
    save_genotype(out_dir, model.genotype(), init_epoch, logger)

    # training loop
    logger.info('begin w/a training')
    best_top1 = 0.
    best_genotype = model.genotype()
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch >= tot_epochs: break
        lr = lr_scheduler.get_lr()[0]
        model.print_alphas(logger)
        # training
        train(w_train_loader, a_train_loader, model, writer, logger, architect,
              w_optim, a_optim, lr, epoch, tot_epochs, device, config)
        # validation
        cur_step = (epoch + 1) * len(w_train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch, tot_epochs,
                        cur_step, device, config)
        # genotype
        genotype = model.genotype()
        save_genotype(out_dir, genotype, epoch, logger)
        # genotype as image
        if config.plot:
            for i, dag in enumerate(model.dags()):
                plot_path = os.path.join(config.plot_path,
                                         "EP{:02d}".format(epoch + 1))
                caption = "Epoch {} - DAG {}".format(epoch + 1, i)
                plot(genotype.dag[i], dag, plot_path + "-dag_{}".format(i),
                     caption)
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler,
                            epoch, logger)
        lr_scheduler.step()
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
    gt.to_file(best_genotype, os.path.join(out_dir, 'best.gt'))
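
# --- Hedged sketch (not part of the original module) ------------------------
# search() expects two training loaders: w_train_loader for the weight step
# and a_train_loader for the architecture step (also reused as valid_loader).
# A common way to build them, assumed here rather than taken from this repo,
# is to split one training set into two disjoint halves.
import torch
import torch.utils.data as data

def split_loaders(dataset, batch_size, split=0.5, num_workers=2):
    n = len(dataset)
    idx = torch.randperm(n).tolist()
    cut = int(n * split)
    w_loader = data.DataLoader(dataset, batch_size,
                               sampler=data.SubsetRandomSampler(idx[:cut]),
                               num_workers=num_workers, pin_memory=True)
    a_loader = data.DataLoader(dataset, batch_size,
                               sampler=data.SubsetRandomSampler(idx[cut:]),
                               num_workers=num_workers, pin_memory=True)
    return w_loader, a_loader
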
def train(train_loader, valid_loader, model, writer, logger, architect,
          w_optim, a_optim, lr, epoch, tot_epochs, device, config):
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    if valid_loader is not None:
        # guard against valid_loader having more batches than train_loader
        tr_ratio = max(1, len(train_loader) // len(valid_loader))
        val_iter = iter(valid_loader)

    eta_m = utils.ETAMeter(tot_epochs, epoch, len(train_loader))
    eta_m.start()
    for step, (trn_X, trn_y) in enumerate(train_loader):
        trn_X = trn_X.to(device, non_blocking=True)
        trn_y = trn_y.to(device, non_blocking=True)
        N = trn_X.size(0)

        # phase 1. child network step (w)
        w_optim.zero_grad()
        tprof.timer_start('train')
        loss, logits = model.loss_logits(trn_X, trn_y, config.aux_weight)
        tprof.timer_stop('train')
        loss.backward()
        # gradient clipping
        if config.w_grad_clip > 0:
            nn.utils.clip_grad_norm_(model.weights(), config.w_grad_clip)
        w_optim.step()

        # phase 2. architect step (alpha)
        if valid_loader is not None and step % tr_ratio == 0:
            try:
                val_X, val_y = next(val_iter)
            except StopIteration:
                val_iter = iter(valid_loader)
                val_X, val_y = next(val_iter)
            val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(
                device, non_blocking=True)
            tprof.timer_start('arch')
            architect.step(trn_X, trn_y, val_X, val_y, lr, w_optim, a_optim)
            tprof.timer_stop('arch')

        prec1, prec5 = utils.accuracy(logits, trn_y, topk=(1, 5))
        losses.update(loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        if (step != 0 and step % config.print_freq == 0) \
                or step == len(train_loader) - 1:
            eta = eta_m.step(step)
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} LR {:.3f} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%}) | ETA: {eta}".
                format(epoch + 1,
                       tot_epochs,
                       step,
                       len(train_loader) - 1,
                       lr,
                       losses=losses,
                       top1=top1,
                       top5=top5,
                       eta=utils.format_time(eta)))

        writer.add_scalar('train/loss', loss.item(), cur_step)
        writer.add_scalar('train/top1', prec1.item(), cur_step)
        writer.add_scalar('train/top5', prec5.item(), cur_step)
        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, tot_epochs, top1.avg))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
    tprof.print_stat('train')
    tprof.print_stat('arch')
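
# --- Hedged sketch (not part of the original module) ------------------------
# train() only requires the architect to expose
# step(trn_X, trn_y, val_X, val_y, lr, w_optim, a_optim). The repo's architect
# (see step() above) additionally samples ops and routes gradients through
# NASModule; this first-order stand-in just illustrates the call contract,
# assuming the model exposes a loss(X, y) method as used there.
class FirstOrderArchitect:
    def __init__(self, config, model):
        self.model = model

    def step(self, trn_X, trn_y, val_X, val_y, lr, w_optim, a_optim):
        a_optim.zero_grad()
        loss = self.model.loss(val_X, val_y)   # architecture loss on the alpha batch
        loss.backward()
        a_optim.step()
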