示例#1
0
 def __init__(self):
     """Set up the search trainer: config, logging, data, model, KD parts, optimizers.

     All settings come from ``SearchConfig()``; the method takes no
     arguments.  The calls run in a fixed order because later steps read
     attributes produced by earlier ones.
     """
     self.config = SearchConfig()
     self.writer = None
     if self.config.tb_dir != "":
         # Imported lazily so TensorBoard is only needed when tb_dir is set.
         from torch.utils.tensorboard import SummaryWriter
         self.writer = SummaryWriter(self.config.tb_dir, flush_secs=20)
     init_gpu_params(self.config)
     set_seed(self.config)
     self.logger = FileLogger('./log', self.config.is_master,
                              self.config.is_master)
     # NOTE(review): presumably sets self.n_classes / self.output_mode,
     # which the model constructor below reads -- confirm.
     self.load_data()
     self.logger.info(self.config)
     self.model = SearchCNNController(self.config, self.n_classes,
                                      self.output_mode)
     self.load_model()
     # NOTE(review): presumably sets self.teacher_model / self.emd_tool,
     # which Architect reads below -- confirm.
     self.init_kd_component()
     if self.config.n_gpu > 0:
         self.model.to(device)
     if self.config.n_gpu > 1:
         self.model = torch.nn.parallel.DistributedDataParallel(
             self.model,
             device_ids=[self.config.local_rank],
             find_unused_parameters=True)
     # Unwrap according to what the model actually is, not a separate flag:
     # the DDP wrap above is keyed on n_gpu, so keying the unwrap on
     # multi_gpu could reach for `.module` on a bare model (AttributeError)
     # or keep the wrapper when the flags disagree.
     is_wrapped = isinstance(
         self.model,
         (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
     self.model_to_print = self.model.module if is_wrapped else self.model
     self.architect = Architect(self.model, self.teacher_model, self.config,
                                self.emd_tool)
     mb_params = param_size(self.model)
     self.logger.info("Model size = {:.3f} MB".format(mb_params))
     # Collected evaluation results (appended to elsewhere, presumably).
     self.eval_result_map = []
     self.init_optim()
示例#2
0
 def __init__(self):
     # nn.Module bookkeeping must exist before any submodule is assigned.
     super().__init__()
     criterion = nn.CrossEntropyLoss().to(device)
     controller = SearchCNNController(
         input_channels, config.init_channels, n_classes, config.layers,
         criterion, device_ids=config.gpus)
     self.model = controller.to(device)
     # NOTE(review): inputs target cuda:0 while the model is moved to the
     # module-level `device`; confirm these always agree.
     has_cuda = torch.cuda.is_available()
     self.input_device = torch.device("cuda:0" if has_cuda else "cpu")
示例#3
0
class ParameterServer(nn.Module):
    """Own the DARTS search model and expose it to remote workers over RPC.

    Tensors handed back to RPC callers are moved to CPU first because, per
    the comments below, RPC as of PyTorch 1.5.0 only accepts CPU tensors.
    """

    def __init__(self):
        super().__init__()
        self.model = SearchCNNController(input_channels,
                                         config.init_channels,
                                         n_classes,
                                         config.layers,
                                         nn.CrossEntropyLoss().to(device),
                                         device_ids=config.gpus).to(device)
        self.input_device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

    def forward(self, inp):
        """Run the model on ``inp`` and return the output as a CPU tensor."""
        print("forwarding ps")
        inp = inp.to(self.input_device)
        out = self.model(inp)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        out = out.to("cpu")
        return out

    # Use dist autograd to retrieve gradients accumulated for this model.
    # Primarily used for verification.
    def get_dist_gradients(self, cid):
        """Return the gradients of distributed-autograd context ``cid`` on CPU."""
        grads = dist_autograd.get_gradients(cid)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        return {k.to("cpu"): v.to("cpu") for k, v in grads.items()}

    def genotype(self):
        """Return an RRef to the model's current genotype."""
        return rpc.RRef(self.model.genotype())

    def _alpha_ids(self):
        """Return the ``id`` of every architecture parameter in ``model._alphas``."""
        return {id(p) for _, p in self.model._alphas}

    def weights(self):
        """Return RRefs to the network weights, excluding the architecture alphas.

        The alphas are exposed separately by ``alphas()``; if they are also
        registered parameters of the model, including them here would let a
        weight optimizer and an alpha optimizer update the same tensors.
        """
        alpha_ids = self._alpha_ids()
        return [rpc.RRef(param) for param in self.model.parameters()
                if id(param) not in alpha_ids]

    def named_weights(self):
        """Return RRefs to ``(name, param)`` pairs of the non-alpha weights."""
        alpha_ids = self._alpha_ids()
        return [rpc.RRef(item) for item in self.model.named_parameters()
                if id(item[1]) not in alpha_ids]

    def alphas(self):
        """Return RRefs to the architecture parameters in ``model._alphas``."""
        return [rpc.RRef(p) for _, p in self.model._alphas]

    def named_alphas(self):
        """Return ``(RRef(name), RRef(param))`` pairs for ``model._alphas``."""
        return [(rpc.RRef(n), rpc.RRef(p)) for n, p in self.model._alphas]
示例#4
0
文件: run.py 项目: kc-ml2/darts
# Display names for the 8 candidate operations, in alpha-column order.
_ALPHA_OP_NAMES = ('max_pl3', 'avg_pl3', 'skip_cn', 'sep_conv3',
                   'sep_conv5', 'dil_conv3', 'dil_conv5', 'none')


def _log_alpha_scalars(tag, alphas, epoch):
    """Write the per-edge softmax of every alpha tensor to TensorBoard under ``tag``."""
    for i, tensor in enumerate(alphas):
        for j, probs in enumerate(F.softmax(tensor, dim=-1)):
            tb_writer.add_scalars(
                '%s/%d ~~ %d' % (tag, j - 2, i),
                {name: probs[k] for k, name in enumerate(_ALPHA_OP_NAMES)},
                epoch)


def main():
    """Run the DARTS architecture search, logging and plotting each epoch's genotype.

    All settings come from the module-level ``config``.  The training set
    is split in half: the first half feeds ``train_loader`` and the second
    half ``valid_loader``; both are passed to ``train`` and the latter to
    ``validate``.
    """
    logger.info("Logger is set - training start")

    torch.cuda.set_device(config.gpus[0])

    # seed setting
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta information
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    # set model
    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                n_nodes=config.nodes,
                                device_ids=config.gpus)
    model = model.to(device)

    # weight optim -- uses the *weight* decay (previously the alpha decay
    # was passed here by mistake; Architect below already uses w_weight_decay).
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)

    # alpha optim
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data (train,validation)
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)

    arch = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop-----------------------------------------------------------------------------
    best_top1 = 0.
    best_genotype = None  # stays None if no epoch improves on best_top1
    for epoch in range(config.epochs):
        # Read this epoch's LR before stepping: since PyTorch 1.1 the scheduler
        # is stepped after the optimizer, and get_lr() outside step() is
        # deprecated and may return a mis-scaled value.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        #training
        train(train_loader, valid_loader, model, arch, w_optim, alpha_optim,
              lr, epoch)
        lr_scheduler.step()

        #validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        #log
        #genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # output alpha per epochs to tensorboard data
        _log_alpha_scalars('epoch_alpha_normal', model.alpha_normal, epoch)
        _log_alpha_scalars('epoch_alpha_reduce', model.alpha_reduce, epoch)

        #save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype is = {}".format(best_genotype))
示例#5
0
def main():
    """Run the DARTS search with a configurable primitive set.

    Settings come from the module-level ``config``.  ``config.ops_set`` and
    ``config.smart_sample`` patch attributes of the genotype module ``gt``;
    when ``config.multi_avg_size > 0`` a Gloo process group is initialised
    through ``init_processes``.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    # Patch module-level state in `gt` before the model is built (which
    # presumably reads gt.PRIMITIVES -- confirm).
    if config.ops_set == 2:
        gt.PRIMITIVES = gt.PRIMITIVES2

    if config.ops_set == 3:
        gt.PRIMITIVES = gt.PRIMITIVES_NO_SKIP

    if config.smart_sample:
        gt.smart_sample = True

    # Initialize the distributed environment.
    if config.multi_avg_size > 0:
        init_processes(config.multi_avg_rank, config.multi_avg_size,
                       backend='Gloo')

    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels, config.init_channels,
                                n_classes, config.layers, net_crit,
                                n_nodes=config.n_nodes,
                                device_ids=config.gpus,
                                proxyless=config.proxyless)
    model = model.to(device)

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(), config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(), config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay,
                          config.hv_type)

    # training loop
    best_top1 = 0.
    best_genotype = None  # stays None if no epoch improves on best_top1
    for epoch in range(config.epochs):
        # Read this epoch's LR before stepping: the scheduler is stepped after
        # the optimizer (PyTorch >= 1.1), and get_lr() outside step() is
        # deprecated and may return a mis-scaled value.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model, architect, w_optim,
              alpha_optim, lr, epoch)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))
        # Genotype plotting is disabled in this variant.

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#6
0
def _restrict_skip_connections(model, max_skip=2, skip_idx=2, n_nodes=4):
    """Zero the skip alpha on the weakest skip edges until at most ``max_skip`` remain.

    Edge ``(i, j)`` (node ``i``, input ``j`` with ``j < 2 + i``) counts as a
    skip edge when the top-1 column of ``model.alpha_normal[i]`` row ``j`` is
    ``skip_idx``.  The skip edge with the smallest alpha at ``skip_idx`` is
    demoted first (ties: earliest edge).  Defaults reproduce the previously
    hard-coded values (4 nodes, column 2, keep 2).
    """
    # The alphas are presumably leaf parameters that require grad; an
    # in-place write to them outside no_grad() raises a RuntimeError.
    with torch.no_grad():
        skip_edges = []
        for i in range(n_nodes):
            _, top_idx = torch.topk(model.alpha_normal[i], 1)
            for j in range(2 + i):
                if top_idx[j].item() == skip_idx:
                    skip_edges.append((i, j))

        while len(skip_edges) > max_skip:
            weakest = 0
            weakest_alpha = model.alpha_normal[skip_edges[0][0]][skip_edges[0][1], skip_idx]
            for k in range(1, len(skip_edges)):
                i, j = skip_edges[k]
                alpha_c = model.alpha_normal[i][j, skip_idx]
                if alpha_c < weakest_alpha:
                    weakest, weakest_alpha = k, alpha_c
            # NOTE(review): 0 is not necessarily below the other alphas on this
            # edge (they may be negative), so the edge can remain skip-connect.
            i, j = skip_edges.pop(weakest)
            model.alpha_normal[i][j, skip_idx] = 0
            print(skip_edges)


def main():
    """Run the DARTS search, then cap skip-connect edges in the normal cell.

    Settings come from the module-level ``config``.  After the search loop,
    at most two normal-cell edges are left with skip-connect as their top
    operation, and the genotype derived from the adjusted alphas is written
    to ``<config.path>/best_genotype.txt`` and logged.
    """
    logger.info("Logger is set - training start")

    # CUDA device selection and CUDA seeding are not performed in this variant.

    # set seed -- Python's `random` is seeded as well because it shuffles the
    # train/validation split below.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False
    )

    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(
        input_channels,
        config.init_channels,
        n_classes,
        config.layers,
        net_crit,
        device_ids=config.gpus,
        imagenet_mode=config.dataset.lower() in utils.LARGE_DATASETS,
    )
    model = model.to(device)

    # weights optimizer
    w_optim = torch.optim.SGD(
        model.weights(),
        config.w_lr,
        momentum=config.w_momentum,
        weight_decay=config.w_weight_decay,
    )
    # alphas optimizer
    alpha_optim = torch.optim.Adam(
        model.alphas(),
        config.alpha_lr,
        betas=(0.5, 0.999),
        weight_decay=config.alpha_weight_decay,
    )

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    random.shuffle(indices)
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config.batch_size,
        sampler=train_sampler,
        num_workers=config.workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config.batch_size,
        sampler=valid_sampler,
        num_workers=config.workers,
        pin_memory=True,
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min
    )
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_top1 = 0.0
    for epoch in range(config.epochs):
        # Read this epoch's LR before stepping: the scheduler is stepped after
        # the optimizer (PyTorch >= 1.1), and get_lr() outside step() is
        # deprecated and may return a mis-scaled value.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        # training
        train(
            train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch
        )
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path, "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        is_best = best_top1 < top1
        if is_best:
            best_top1 = top1
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    # restrict skip-connections, then derive the reported genotype from the
    # adjusted final alphas (this is the last epoch's architecture, not the
    # epoch with the best validation accuracy).
    _restrict_skip_connections(model)
    best_genotype = model.genotype()

    with open(os.path.join(config.path, "best_genotype.txt"), "w") as f:
        f.write(str(best_genotype))
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#7
0
async def main():
    """Run the DARTS search with batches pre-distributed to remote workers.

    Settings come from the module-level ``config``.  Every batch of both
    loaders is ``.send()``-ed to one of ``workers`` (round-robin by batch
    index) and appended to ``remote_train_data`` / ``remote_valid_data``
    before the search loop starts; ``train`` is awaited each epoch.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    net_crit = nn.CrossEntropyLoss().to(default_device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                device_ids=config.gpus)
    model = model.to(default_device)

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)

    # Distribute batches round-robin (by batch index) across the workers.
    for loader, remote_data in ((train_loader, remote_train_data),
                                (valid_loader, remote_valid_data)):
        for idx, (data, target) in enumerate(loader):
            wid = idx % len(workers)
            data = data.send(workers[wid])
            target = target.send(workers[wid])
            remote_data[wid].append((data, target))

    print("finish sampler")

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_top1 = 0.
    best_genotype = None  # stays None if no epoch improves on best_top1
    for epoch in range(config.epochs):
        # Read this epoch's LR before stepping: the scheduler is stepped after
        # the optimizer (PyTorch >= 1.1), and get_lr() outside step() is
        # deprecated and may return a mis-scaled value.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        # training
        await train(train_loader, valid_loader, model, architect, w_optim,
                    alpha_optim, lr, epoch)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#8
0
def main():
    """Run the DARTS search with optional resume from, and per-epoch saving of, full state.

    Settings come from the module-level ``config``.  When ``config.restore``
    is set, the model (and, unless ``config.model_only``, both optimizers and
    the LR scheduler) are loaded from ``config.path`` via
    ``utils.load_state_dict``; the loop then runs from
    ``config.epoch_restore`` to ``config.epochs``.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                device_ids=config.gpus)
    model = model.to(device)

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=False)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=False)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_top1 = -1.0
    best_epoch = 0
    best_genotype = None  # stays None if the loop below does not run
    is_parallel = len(config.gpus) > 1

    # ---- restore from last time ------------------------------------------
    epoch_restore = config.epoch_restore
    if config.restore:
        utils.load_state_dict(model,
                              config.path,
                              extra='model',
                              parallel=is_parallel)
        if not config.model_only:
            # NOTE(review): epoch_restore is an int, which a callee cannot
            # rebind, so the resume epoch still comes from
            # config.epoch_restore; confirm what utils.load_state_dict does
            # with it.
            for obj, extra in ((w_optim, 'w_optim'),
                               (alpha_optim, 'alpha_optim'),
                               (lr_scheduler, 'lr_scheduler'),
                               (epoch_restore, 'epoch_restore')):
                utils.load_state_dict(obj,
                                      config.path,
                                      extra=extra,
                                      parallel=False)
    # -----------------------------------------------------------------------
    for epoch in range(epoch_restore, config.epochs):
        # Read this epoch's LR before stepping: the scheduler is stepped after
        # the optimizer (PyTorch >= 1.1), and get_lr() outside step() is
        # deprecated and may return a mis-scaled value.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model, architect, w_optim,
              alpha_optim, lr, epoch)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
            best_epoch = epoch + 1
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)

        # ---- save all state (same order as before) ------------------------
        last_state = (epoch + 1) >= config.epochs
        for obj, extra, parallel in ((model, 'model', is_parallel),
                                     (lr_scheduler, 'lr_scheduler', False),
                                     (alpha_optim, 'alpha_optim', False),
                                     (w_optim, 'w_optim', False)):
            utils.save_state_dict(obj,
                                  config.path,
                                  extra=extra,
                                  is_best=is_best,
                                  parallel=parallel,
                                  epoch=epoch + 1,
                                  acc=top1,
                                  last_state=last_state)
        # -------------------------------------------------------------------
        print("")
    logger.info("Best Genotype at {} epch.".format(best_epoch))
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#9
0
from models.search_cnn import SearchCNNController
from visualize import plot

# Build a freshly constructed search model (no loss criterion) only to
# derive, print and draw its current genotype.
model = SearchCNNController(3, 16, 10, 20, None, n_nodes=4)
genotype = model.genotype()

print(genotype)

# One diagram per cell type.
for cell, filename in ((genotype.normal, "normal_cell_genotype"),
                       (genotype.reduce, "reduce_cell_genotype")):
    plot(cell, filename)
示例#10
0
def worker(gpu, ngpus_per_node, config_in):
    """Run one distributed architecture-search process.

    Sets up TensorBoard and file logging, joins the process group, wraps a
    ``SearchCNNController`` in ``DistributedDataParallel``, splits the
    training set in half (one half for the network weights, one for the
    architecture parameters) and runs the search loop. Each epoch the derived
    genotype is logged, plotted and checkpointed.

    Args:
        gpu: Local GPU index used by this process, or ``None`` to let DDP
            spread over all visible GPUs.
        ngpus_per_node: GPUs on this node; used to compute the global rank
            (when ``config.mp_dist``) and to split batch size / worker count.
        config_in: Search configuration. It is deep-copied because rank,
            batch size and worker count are mutated per process below.

    Raises:
        KeyError: if ``SLURM_JOBID``/``SLURM_PROCID`` are unset, or ``RANK``
            is unset when ``dist_url == "env://"`` and ``rank == -1``.
    """
    # Private copy so per-process mutations do not leak to other workers.
    config = copy.deepcopy(config_in)
    jobid = os.environ["SLURM_JOBID"]
    procid = int(os.environ["SLURM_PROCID"])
    config.gpu = gpu

    # Per-process name suffix so concurrent workers write distinct
    # TensorBoard, log, plot and checkpoint files.
    if config.gpu is not None:
        suffix = "{}-{:d}-{:d}".format(jobid, procid, gpu)
    else:
        suffix = "{}-{:d}-all".format(jobid, procid)
    writer_name = "tb." + suffix
    logger_name = "{}.{}.search.log".format(config.name, suffix)
    ploter_name = suffix
    ck_name = suffix

    writer = SummaryWriter(log_dir=os.path.join(config.path, writer_name))
    writer.add_text('config', config.as_markdown(), 0)
    logger = utils.get_logger(os.path.join(config.path, logger_name))

    config.print_params(logger.info)

    # get cuda device
    device = torch.device('cuda', gpu)

    # begin
    logger.info("Logger is set - training start")

    if config.dist_url == "env://" and config.rank == -1:
        config.rank = int(os.environ["RANK"])

    if config.mp_dist:
        # For multiprocessing distributed training, rank needs to be the
        # global rank among all the processes
        config.rank = config.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
                            world_size=config.world_size, rank=config.rank)

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    # build model
    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels, config.init_channels, n_classes, config.layers,
                                net_crit)
    if config.gpu is not None:
        torch.cuda.set_device(config.gpu)
        model.cuda(config.gpu)
        # When using a single GPU per process and per DistributedDataParallel, we need to divide
        # the batch size ourselves based on the total number of GPUs we have
        config.batch_size = int(config.batch_size / ngpus_per_node)
        config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
    else:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)

    # weights optimizer
    w_optim = torch.optim.SGD(model.module.weights(), config.w_lr, momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.module.alphas(), config.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation (first half: weights, second half: alphas)
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_data_ = data.Subset(train_data, indices[:split])
    valid_data_ = data.Subset(train_data, indices[split:])
    train_sampler = data.distributed.DistributedSampler(train_data_,
                                                        num_replicas=config.world_size,
                                                        rank=config.rank)
    valid_sampler = data.distributed.DistributedSampler(valid_data_,
                                                        num_replicas=config.world_size,
                                                        rank=config.rank)
    train_loader = data.DataLoader(train_data_,
                                   batch_size=config.batch_size,
                                   sampler=train_sampler,
                                   shuffle=False,
                                   num_workers=config.workers,
                                   pin_memory=True)
    valid_loader = data.DataLoader(valid_data_,
                                   batch_size=config.batch_size,
                                   sampler=valid_sampler,
                                   shuffle=False,
                                   num_workers=config.workers,
                                   pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # setting the privacy protecting procedure
    # NOTE(review): this logs "OFF" when the flag is set; presumably the
    # privacy engine is not implemented in this script -- confirm.
    if config.dist_privacy:
        logger.info("PRIVACY ENGINE OFF")

    # training loop
    best_top1 = 0.0
    best_genotype = None  # stays None if no epoch beats best_top1
    for epoch in range(config.epochs):
        # DistributedSampler seeds its shuffle from the epoch; without
        # set_epoch() every epoch would replay the same sample order.
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)

        # LR used by the optimizer for this epoch (scheduler steps afterwards).
        lr = lr_scheduler.get_last_lr()[0]

        model.module.print_alphas(logger)

        # training
        train(logger, writer, device, config,
              train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch)
        lr_scheduler.step()  # after the epoch's optimizer.step() calls

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(logger, writer, device, config,
                        valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.module.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path, "JOB" + ploter_name + "-EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, ck_name, 'search', is_best)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#11
0
def main(config, writer, logger):
    """Run differentiable architecture search and report the best result.

    Two search modes are selected by ``config.search_all_alpha``:

    * ``False``: the training set is split in half. One half trains the
      network weights (SGD); the other half drives the architecture
      parameters (Adam, via ``Architect``).
    * ``True``: the whole training set trains the weights, the architecture
      is handled by ``SubgraphSearchOptimizer``, and ``test_data`` is used
      for validation.

    The loop stops early when ``config.cutoff_epochs`` is reached or the
    learning rate drops below ``config.w_lr_min``. Validation runs every
    ``config.validate_epochs`` epochs (never when it is 0). When
    ``config.nni`` is set, results are reported through ``nni_tools``.

    Args:
        config: Search configuration object.
        writer: Passed through to ``train``/``validate`` (presumably a
            TensorBoard ``SummaryWriter``).
        logger: Logger for progress messages.

    Raises:
        NotImplementedError: if ``config.w_lr_scheduler`` is neither
            ``"cosine"`` nor ``"plateau"``.
    """
    logger.info("Logger is set - training start")

    # set seed
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing, which
    # can undermine deterministic=True; set benchmark=False if exact
    # reproducibility is required.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data, test_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=config.cutout_length, validation=True)

    net_crit = nn.CrossEntropyLoss().cuda()
    model = SearchCNNController(input_channels, config.init_channels, n_classes, config.n_layers, net_crit,
                                n_nodes=config.n_nodes, stem_multiplier=config.stem_multiplier,
                                bn_momentum=config.bn_momentum)
    model.cuda()

    # weights optimizer
    w_optim = optim.SGD(model.weights(), config.w_lr, momentum=config.w_momentum,
                        weight_decay=config.w_weight_decay)

    if not config.search_all_alpha:
        # alphas optimizer
        alpha_optim = optim.Adam(model.alphas(), config.alpha_lr, betas=(0.5, 0.999),
                                 weight_decay=config.alpha_weight_decay)
        # split data to train/validation
        n_train = len(train_data)
        indices = list(range(n_train))
        split = n_train // 2
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=config.batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=config.workers,
                                                   pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=config.batch_size,
                                                   sampler=valid_sampler,
                                                   num_workers=config.workers,
                                                   pin_memory=True)
    else:
        alpha_optim = SubgraphSearchOptimizer(logger, config, model, w_optim)

        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=config.batch_size,
                                                   num_workers=config.workers,
                                                   pin_memory=True,
                                                   shuffle=True)
        valid_loader = torch.utils.data.DataLoader(test_data,
                                                   batch_size=config.batch_size,
                                                   num_workers=config.workers,
                                                   pin_memory=True,
                                                   shuffle=True)

    if config.w_lr_scheduler == "cosine":
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(w_optim, T_max=config.epochs, eta_min=config.w_lr_min)
    elif config.w_lr_scheduler == "plateau":
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(w_optim, mode="max", patience=config.w_lr_patience,
                                                            factor=config.w_lr_factor, verbose=True)
    else:
        raise NotImplementedError
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_top1 = 0.
    best_genotype = None
    top1 = None  # last validation accuracy; stays None if validation never ran
    final_result_reported = False
    for epoch in range(config.epochs):

        if config.cutoff_epochs is not None and epoch >= config.cutoff_epochs:
            logger.info("Cutoff epochs detected, exiting.")
            break

        lr = w_optim.param_groups[0]["lr"]
        logger.info("Current learning rate: {}".format(lr))

        if lr < config.w_lr_min:
            logger.info("Learning rate is less than {}, exiting.".format(config.w_lr_min))
            break

        if not config.search_all_alpha:
            model.print_alphas(logger)
            valid_loader_for_training = valid_loader
        else:
            # The subgraph optimizer does not consume validation batches
            # during training, so feed an endless stream of dummy pairs.
            valid_loader_for_training = itertools.cycle([(torch.tensor(1), torch.tensor(1))])

        # training
        train(config, writer, logger, train_loader, valid_loader_for_training,
              model, architect, w_optim, alpha_optim, lr, epoch, valid_loader)

        if config.w_lr_scheduler == "cosine":
            lr_scheduler.step()

        # validation
        if config.validate_epochs == 0 or (epoch + 1) % config.validate_epochs != 0:
            logger.info("Valid: Skipping validation for epoch {}".format(epoch + 1))
            continue

        cur_step = (epoch + 1) * len(train_loader)
        if config.search_all_alpha:
            top1 = validate_all(config, writer, logger, valid_loader, model, epoch, cur_step, alpha_optim)
            if best_top1 < top1:
                best_top1 = top1

            # checkpoint saving is not supported yet
        else:
            top1 = validate(config, writer, logger, valid_loader, model, epoch, cur_step)

            # log
            # genotype
            genotype = model.genotype()
            logger.info("genotype = {}".format(genotype))

            # genotype as a image
            plot_path = os.path.join(config.plot_path, "EP{:02d}".format(epoch + 1))
            caption = "Epoch {}".format(epoch + 1)
            plot(genotype.normal, plot_path + "-normal", caption)

            # save
            if best_top1 < top1:
                best_top1 = top1
                best_genotype = genotype
                is_best = True
            else:
                is_best = False
            utils.save_checkpoint(model, config.path, is_best)

        if config.nni:
            nni_tools.report_result(top1, epoch + 1 == config.epochs)
            if epoch + 1 == config.epochs:
                final_result_reported = True

        if config.w_lr_scheduler == "plateau":
            lr_scheduler.step(top1)
        print("")

    if config.nni and not final_result_reported:
        if top1 is None:
            logger.warning("Final result not reported and top1 not found")
        else:
            # Best-effort final report: a reporting failure must not crash
            # an otherwise finished search, but it should not be silent.
            try:
                nni_tools.report_result(top1, True)
            except Exception as err:
                logger.warning("Failed to report final result to NNI: {}".format(err))

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    if best_genotype is not None:
        logger.info("Best Genotype = {}".format(best_genotype))
    print("")
示例#12
0
sys.path.append('..')
import torch
import torch.nn as nn
from models.search_cnn import SearchCNNController
import copy
import time
import argparse


# Command line: a single boolean switch for the fixed-input sanity check.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--sanity",
    action="store_true",
    default=False,
    help="sanity check",
)
args = parser.parse_args()


# Shared loss criterion on the GPU, then two identically-configured search
# controllers. `src` is constructed before `tgt`, exactly as two sequential
# constructions would be, so random initialisation order is unchanged.
net_crit = nn.CrossEntropyLoss().cuda()
src, tgt = [
    SearchCNNController(3, 16, 10, 8, net_crit, device_ids=[0]).cuda()
    for _ in range(2)
]


### Settings ###
B = 64
if args.sanity:
    print("Sanity check ...")
    N = 1

    # fixed inputs
    gen_X, gen_y = torch.randn(B, 3, 32, 32).cuda(), torch.randint(10, [B], dtype=torch.long).cuda()

    def gen_inputs(B):
        return copy.deepcopy(gen_X), copy.deepcopy(gen_y)
else:
def main(config, writer, logger, checkpoint, base_step):
    """Fine-tune the weights of a pretrained search model loaded from a checkpoint.

    The learning rate is fixed for the whole run. Its value is what a cosine
    schedule from ``config.w_lr`` to ``config.w_lr_min`` over
    ``config.epochs`` epochs would give at the epoch reached after
    ``base_step`` steps.

    If ``config.finetune_max_steps`` is ``None``, the model is validated after
    every epoch and the best top-1 accuracy is tracked. Otherwise validation
    is skipped, and training stops once ``finetune_max_steps`` steps have been
    taken; in that mode the reported best Prec@1 stays 0.

    Args:
        config: Configuration object.
        writer: Passed through to ``train``/``validate``.
        logger: Logger for progress messages.
        checkpoint: Path of a model ``state_dict`` readable by ``torch.load``.
        base_step: Global step at which the checkpoint was taken; used to
            position the LR on the cosine schedule.

    Raises:
        AssertionError: if ``config.w_lr_scheduler`` is not ``"cosine"``.
    """
    logger.info("Pretrained checkpoint: {}".format(checkpoint))

    # set seed
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # NOTE(review): benchmark=True can undermine deterministic=True.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data, test_data = utils.get_data(
        config.dataset,
        config.data_path,
        cutout_length=config.cutout_length,
        validation=True)

    net_crit = nn.CrossEntropyLoss().cuda()
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.n_layers,
                                net_crit,
                                n_nodes=config.n_nodes,
                                stem_multiplier=config.stem_multiplier,
                                bn_momentum=config.bn_momentum)
    model.cuda()
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint))

    # Full batches per epoch, clamped to 1 so a dataset smaller than one
    # batch does not raise ZeroDivisionError.
    steps_per_epoch = max(1, len(train_data) // config.batch_size)
    base_epoch_number = base_step // steps_per_epoch
    # Only the cosine schedule's closed form is reproduced below.
    assert config.w_lr_scheduler == "cosine"
    base_lr = config.w_lr_min + (config.w_lr - config.w_lr_min) * \
              (1 + math.cos(math.pi * base_epoch_number / config.epochs)) / 2
    logger.info("Learning rate: {}".format(base_lr))

    # weights optimizer
    w_optim = optim.SGD(model.weights(),
                        base_lr,
                        momentum=config.w_momentum,
                        weight_decay=config.w_weight_decay)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(test_data,
                                               batch_size=config.batch_size,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               shuffle=True)

    # training loop
    best_top1 = 0.
    for epoch in range(config.finetune_epochs):
        lr = w_optim.param_groups[0]["lr"]

        # training
        train(config, writer, logger, train_loader, valid_loader, model,
              w_optim, lr, epoch)

        cur_step = (epoch + 1) * len(train_loader)
        if config.finetune_max_steps is not None:
            # Step-budgeted mode: no per-epoch validation (there is no top1
            # to compare), just stop once the budget is spent.
            if cur_step >= config.finetune_max_steps:
                break
            continue

        # validation
        top1 = validate(config,
                        writer,
                        logger,
                        valid_loader,
                        model,
                        epoch,
                        cur_step,
                        total_epochs=config.finetune_epochs)

        # save
        if best_top1 < top1:
            best_top1 = top1

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    print("")
示例#14
0
def main():
    """Search two architectures jointly, each from its own initialisation.

    The two ``SearchCNNController`` models are seeded with ``config.seed``
    and ``config.seed + 1``. They share one ``Architect`` and are trained
    together by ``train`` (which also receives ``config.lmbda``). Each model
    has its own weight/alpha optimizers and cosine LR schedule. After every
    epoch both are validated, and their genotypes are logged and
    checkpointed.

    Relies on module-level ``config``, ``logger`` and ``device``.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    net_crit = nn.CrossEntropyLoss().to(device)

    # Re-seed before each construction so the two models start from
    # different but reproducible initialisations.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    model_1 = SearchCNNController(input_channels,
                                  config.init_channels,
                                  n_classes,
                                  config.layers,
                                  net_crit,
                                  device_ids=config.gpus)

    torch.manual_seed(config.seed + 1)
    torch.cuda.manual_seed_all(config.seed + 1)
    model_2 = SearchCNNController(input_channels,
                                  config.init_channels,
                                  n_classes,
                                  config.layers,
                                  net_crit,
                                  device_ids=config.gpus)

    torch.backends.cudnn.benchmark = True

    model_1 = model_1.to(device)
    model_2 = model_2.to(device)

    # weights optimizer
    w_optim_1 = torch.optim.SGD(model_1.weights(),
                                config.w_lr,
                                momentum=config.w_momentum,
                                weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim_1 = torch.optim.Adam(model_1.alphas(),
                                     config.alpha_lr,
                                     betas=(0.5, 0.999),
                                     weight_decay=config.alpha_weight_decay)

    # weights optimizer
    w_optim_2 = torch.optim.SGD(model_2.weights(),
                                config.w_lr,
                                momentum=config.w_momentum,
                                weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim_2 = torch.optim.Adam(model_2.alphas(),
                                     config.alpha_lr,
                                     betas=(0.5, 0.999),
                                     weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim_1, config.epochs, eta_min=config.w_lr_min)
    lr_scheduler_2 = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim_2, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model_1, model_2, config.w_momentum,
                          config.w_weight_decay)

    # training loop
    best_top1_1 = 0.
    best_top1_2 = 0.
    best_genotype_1 = None  # stays None if model_1 never improves
    best_genotype_2 = None  # stays None if model_2 never improves
    for epoch in range(config.epochs):
        # Read the LR the optimizers will actually use this epoch.
        # get_lr() outside step() re-applies the recursive cosine update, and
        # stepping before optimizer.step() skips the first schedule value.
        lr_1 = lr_scheduler_1.get_last_lr()[0]
        lr_2 = lr_scheduler_2.get_last_lr()[0]

        model_1.print_alphas(logger)
        model_2.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model_1, model_2, architect,
              w_optim_1, w_optim_2, alpha_optim_1, alpha_optim_2, lr_1, lr_2,
              epoch, config.lmbda)

        # Advance the schedules after this epoch's optimizer steps.
        lr_scheduler_1.step()
        lr_scheduler_2.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1_1, top1_2 = validate(valid_loader, model_1, model_2, epoch,
                                  cur_step)

        # log
        # genotype
        genotype_1 = model_1.genotype()
        genotype_2 = model_2.genotype()
        logger.info("genotype_1 = {}".format(genotype_1))
        logger.info("genotype_2 = {}".format(genotype_2))

        # save
        if best_top1_1 < top1_1:
            best_top1_1 = top1_1
            best_genotype_1 = genotype_1
            is_best_1 = True
        else:
            is_best_1 = False

        if best_top1_2 < top1_2:
            best_top1_2 = top1_2
            best_genotype_2 = genotype_2
            is_best_2 = True
        else:
            is_best_2 = False

        utils.save_checkpoint(model_1, config.path, 1, is_best_1)
        utils.save_checkpoint(model_2, config.path, 2, is_best_2)
        print("")

    logger.info("Final best Prec@1_1 = {:.4%}".format(best_top1_1))
    logger.info("Best Genotype_1 = {}".format(best_genotype_1))
    logger.info("Final best Prec@1_2 = {:.4%}".format(best_top1_2))
    logger.info("Best Genotype_2 = {}".format(best_genotype_2))
示例#15
0
def main():
    """Search a cell architecture for STMap regression on the VIPL dataset.

    Loads the precomputed STMap training set for the configured fold. When
    ``config.reData == 1`` it first calls ``MyDataset.CrossValidation`` on
    the raw data. It then trains a single-output ``SearchCNNController``
    with L1 loss, using half of the training set for weights and half for
    architecture parameters. The genotype with the lowest validation loss is
    tracked and checkpointed.

    Relies on module-level ``config``, ``logger`` and ``device``.
    """
    logger.info("Logger is set - training start")
    fileRoot = r'/home/hlu/Data/VIPL'
    saveRoot = r'/home/hlu/Data/VIPL_STMap' + str(config.fold_num) + str(
        config.fold_index)
    n_classes = 1
    input_channels = 3
    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    toTensor = transforms.ToTensor()
    resize = transforms.Resize(size=(64, 300))
    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune kernels for speed

    if config.reData == 1:
        test_index, train_index = MyDataset.CrossValidation(
            fileRoot, fold_num=config.fold_num, fold_index=config.fold_index)
    train_data = MyDataset.Data_STMap(root_dir=(saveRoot + '_Train'),
                                      frames_num=300,
                                      transform=transforms.Compose(
                                          [resize, toTensor, normalize]))
    net_crit = nn.L1Loss().to(device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                device_ids=config.gpus)
    model._init_weight()
    model = model.to(device)
    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)
    # training loop
    # Lower loss is better; start at +inf so the first epoch always counts
    # (a finite sentinel would never mark a high-loss run as best).
    best_losses = float("inf")
    best_genotype = None
    for epoch in range(config.epochs):
        # LR the optimizer uses this epoch; the schedule advances after training.
        lr = lr_scheduler.get_last_lr()[0]
        model.print_alphas(logger)
        # training
        train(train_loader, valid_loader, model, architect, w_optim,
              alpha_optim, lr, epoch)
        lr_scheduler.step()
        # validation
        cur_step = (epoch + 1) * len(train_loader)
        losses = validate(valid_loader, model, epoch, cur_step)
        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # save
        if losses < best_losses:
            best_losses = losses
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Best Genotype = {}".format(best_genotype))
示例#16
0
def main():
    """Search a cell architecture on the text classification data sets.

    Trains a ``SearchCNNController`` on ``TRAIN_DATA_PATH`` and validates on
    ``DEV_DAHAI_DATA_PATH``, using ``utils.DataProvider`` batching. Epochs are
    counted by the provider (``train_data.epoch``). Each epoch's genotype is
    logged, plotted and checkpointed, and the best-accuracy genotype is
    reported at the end.

    Relies on module-level ``config``, ``logger``, ``device`` and the
    data-path constants.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    dahai_train_dataset = utils.MyDataset(data_dir=TRAIN_DATA_PATH, )
    dahai_dev_dataset = utils.MyDataset(data_dir=DEV_DAHAI_DATA_PATH, )

    train_data = utils.DataProvider(batch_size=config.batch_size,
                                    dataset=dahai_train_dataset,
                                    is_cuda=config.is_cuda)
    dev_data = utils.DataProvider(batch_size=config.batch_size,
                                  dataset=dahai_dev_dataset,
                                  is_cuda=config.is_cuda)

    print("train data nums:", len(train_data.dataset), "dev data nums:",
          len(dev_data.dataset))

    # Per-sample losses (reduction="none"); reduction is left to the caller.
    net_crit = nn.CrossEntropyLoss(reduction="none").to(device)
    model = SearchCNNController(config.embedding_dim,
                                config.init_channels,
                                config.n_classes,
                                config.layers,
                                net_crit,
                                config=config,
                                n_nodes=config.n_nodes,
                                device_ids=config.gpus)
    model = model.to(device).float()

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # Cosine annealing of the weight learning rate.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    best_acc = 0.
    best_genotype = model.genotype()
    while True:
        # The data provider tracks completed epochs.
        epoch = train_data.epoch
        if epoch > config.epochs - 1:
            break
        # LR the optimizer uses this epoch; the schedule advances after training.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        # training
        train(train_data, dev_data, epoch, model, architect, w_optim,
              alpha_optim, lr)
        lr_scheduler.step()

        # validation
        cur_step = train_data.iteration
        valid_acc = validate(dev_data, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_acc < valid_acc:
            best_acc = valid_acc
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_acc))
    logger.info("Best Genotype = {}".format(best_genotype))
示例#17
0
def run_worker():
    """Run one RPC trainer process of the distributed DARTS search.

    Joins the RPC group as ``trainer_<rank>``, builds a ``TrainerNet`` whose
    weights and alphas are optimised through ``DistributedOptimizer``, creates
    one LR scheduler next to each remote weight optimizer, trains on this
    rank's shard of the data for ``config.epochs`` epochs, logs/plots the
    genotype after every epoch and finally shuts RPC down.

    NOTE(review): ``train_data``, ``input_channels``, ``n_classes``,
    ``noise_add`` and ``shape_gaussian`` are not defined in this function
    (the ``utils.get_data`` call was commented out); they are presumably
    module-level globals — confirm.
    """
    rpc.init_rpc(name=f"trainer_{config.rank}",
                 rank=config.rank,
                 world_size=config.world_size)
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    net_crit = nn.CrossEntropyLoss().to(device)
    model = TrainerNet(net_crit)

    # weights optimizer (updates are executed where the parameter RRefs live)
    w_optim = DistributedOptimizer(torch.optim.SGD,
                                   model.weights(),
                                   lr=config.w_lr,
                                   momentum=config.w_momentum,
                                   weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = DistributedOptimizer(torch.optim.Adam,
                                       model.alphas(),
                                       lr=config.alpha_lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=config.alpha_weight_decay)

    # Split the data set into two halves (train / validation) and give this
    # rank its contiguous 1/world_size share of each half.
    n_train = len(train_data)
    split = n_train // 2
    n_valid = n_train - split
    world = config.world_size
    rank = config.rank
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[int(rank * split / world):int((rank + 1) * split / world)])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split + int(rank * n_valid / world):
                split + int((rank + 1) * n_valid / world)])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)

    # One LR scheduler per remote weight optimizer, created on the optimizer's
    # owner so it can step that optimizer locally.
    lrs_rrefs = [
        rpc.remote(opt_rref.owner(), create_lr_scheduler, args=(opt_rref, ))
        for opt_rref in w_optim.remote_optimizers
    ]

    v_model = SearchCNNController(input_channels,
                                  config.init_channels,
                                  n_classes,
                                  config.layers,
                                  nn.CrossEntropyLoss().to(device),
                                  device_ids=config.gpus).to(device)
    architect = Architect(model, v_model, config.w_momentum,
                          config.w_weight_decay, noise_add)

    if noise_add:
        logger.info("Adding noise")
        # One Gaussian per distinct parameter shape (keyed by shape).
        for param in model.parameters():
            shape_gaussian[param.data.shape] = gaussian.MultivariateNormal(
                torch.zeros(param.data.shape), torch.eye(param.data.shape[-1]))
    else:
        logger.info("Not adding noise")

    # training loop
    best_top1 = 0.
    # Stays None (and is logged as such) if no epoch ever improves on 0.
    best_genotype = None
    for epoch in range(config.epochs):
        # Step every remote scheduler and wait until all of them are done.
        futs = [
            rpc.rpc_async(lrs_rref.owner(), lrs_step, args=(lrs_rref, ))
            for lrs_rref in lrs_rrefs
        ]
        for fut in futs:
            fut.wait()
        # Read the current LR from the first scheduler. `lrs_rrefs` is a
        # list, so the owner must be taken from the RRef, not the list.
        lr = remote_method(get_lrs_value,
                           lrs_rrefs[0].owner(),
                           args=(lrs_rrefs[0], ))

        # training
        train(train_loader, valid_loader, model, architect, w_optim,
              alpha_optim, lr, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as an image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    rpc.shutdown()
示例#18
0
def main():
    """Run the random-sampling architecture search and log the best results.

    Builds the train/validation/test loaders and the ``SearchCNNController``
    super-network with its weight and alpha optimizers, and optionally restores
    a checkpoint. Then, for each of ``config.epochs`` epochs, it:

    - samples an architecture with ``sample_arch`` and trains it;
    - evaluates it on the validation and test splits;
    - pickles the test predictions to ``predictions_<epoch>.p``;
    - plots the genotype and saves a checkpoint.

    The best architecture (by validation Prec@1) is pickled to
    ``best_arch.p``.
    """
    import pickle

    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data, val_dat, test_dat = utils.get_data(
        config.dataset,
        config.data_path,
        cutout_length=0,
        validation=True,
        validation2=True,
        img_resize=config.img_resize)

    net_crit = nn.CrossEntropyLoss().to(device)
    model = SearchCNNController(input_channels,
                                config.init_channels,
                                n_classes,
                                config.layers,
                                net_crit,
                                device_ids=config.gpus)
    # comment out if generating an ONNX graph
    model = model.to(device)

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(),
                              config.w_lr,
                              momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(),
                                   config.alpha_lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    print(train_data)

    # Only the first len(train_data) // data_train_proportion training samples
    # are used; validation and test splits are used in full.
    n_train = len(train_data) // int(config.data_train_proportion)
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        list(range(n_train)))
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        list(range(len(val_dat))))
    test_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        list(range(len(test_dat))))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(val_dat,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dat,
                                              batch_size=config.batch_size,
                                              sampler=test_sampler,
                                              num_workers=config.workers,
                                              pin_memory=True)

    if config.load:
        # The second returned value is presumably the epoch stored by
        # utils.save_checkpoint (which is passed `epoch` below). Keep it out of
        # config.epochs so the run length and the LR schedule still follow the
        # configuration.
        model, loaded_epoch, w_optim, alpha_optim, net_crit = utils.load_checkpoint(
            model, config.epochs, w_optim, alpha_optim, net_crit,
            '/content/MyDarts/searchs/custom/checkpoint.pth.tar')
        logger.info("Checkpoint loaded (stored epoch value: {})".format(loaded_epoch))
    # Remove the quotes below to export an ONNX / Keras graph.
    """
    dummy_input = Variable(torch.randn(1, 3, 64, 64))
    torch.onnx.export(model, dummy_input, "rsdarts.onnx", verbose=True)
    input_np = np.random.uniform(0, 1, (1, 3, 64, 64))
    input_var = Variable(torch.FloatTensor(input_np))
    from pytorch2keras.converter import pytorch_to_keras
    # we should specify shape of the input tensor
    output = model(input_var)
    k_model = pytorch_to_keras(model, input_var, (3, 64, 64,), verbose=True)

    error = check_error(output, k_model, input_np)
    if max_error < error:
        max_error = error

    print('Max error: {0}'.format(max_error))
    a=2/0
    """
    # T_max equals the number of loop iterations below, so the cosine schedule
    # decays exactly once over the run.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)

    # training loop
    best_top1 = 0.
    best_top_overall = -999
    best_genotype = None
    best_genotype_overall = None
    for epoch in range(config.epochs):
        # LR actually set on the optimizer for this epoch. get_lr() must not
        # be used here: outside of step() it re-applies the cosine factor.
        lr = lr_scheduler.get_last_lr()[0]

        model.print_alphas(logger)

        print("###################TRAINING#########################")
        # sample a random architecture and train it
        arch = sample_arch(model)
        train(train_loader, valid_loader, model, arch, w_optim, alpha_optim,
              lr, epoch)
        print("###################END TRAINING#########################")

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        print("###################VALID#########################")
        top1, top_overall, _, _ = validate(valid_loader,
                                           model,
                                           arch,
                                           epoch,
                                           cur_step,
                                           overall=True)
        print("###################END VALID#########################")

        # test
        print("###################TEST#########################")
        _, _, preds, targets = validate(test_loader,
                                        model,
                                        arch,
                                        epoch,
                                        cur_step,
                                        overall=True,
                                        debug=True)
        with open("predictions_" + str(epoch + 1) + ".p", "wb") as pred_file:
            pickle.dump([preds, targets], pred_file)
        print("###################END TEST#########################")

        # log genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as an image
        plot_path = os.path.join(config.plot_path,
                                 "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        print("Genotype normal:", genotype.normal)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save best by validation Prec@1
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            best_arch = arch
            is_best = True
            with open("best_arch.p", "wb") as arch_file:
                pickle.dump(best_arch, arch_file)
            print('best_arch:', best_arch)
            print("saved!")
        else:
            is_best = False
        # save best overall (macro avg of f1, precision and recall)
        if best_top_overall < top_overall:
            best_top_overall = top_overall
            best_genotype_overall = genotype
            is_best_overall = True
        else:
            is_best_overall = False

        utils.save_checkpoint(model, epoch, w_optim, alpha_optim, net_crit,
                              config.path, is_best, is_best_overall)

        # Advance the schedule after this epoch's optimizer steps
        # (PyTorch >= 1.1 ordering; stepping first skips the initial LR).
        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    logger.info("Best Genotype Overall = {}".format(best_genotype_overall))
示例#19
0
class SearchModel():
    def __init__(self):
        """Set up everything the search needs: config, logging, data, model, KD teacher, optimizers.

        Call order matters:

        - ``load_data`` logs through ``self.logger``. It also sets
          ``self.n_classes`` / ``self.output_mode`` (and may replace
          ``self.config``) before the controller is built.
        - ``init_kd_component`` sets ``self.teacher_model`` /
          ``self.emd_tool``, which are handed to ``Architect``.
        - ``init_optim`` runs last, so it collects parameters from the final
          (possibly DDP-wrapped) model.
        """
        self.config = SearchConfig()
        # TensorBoard writer; stays None unless a tb_dir is configured.
        self.writer = None
        if self.config.tb_dir != "":
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.config.tb_dir, flush_secs=20)
        # Project helpers; presumably set up GPU/distributed state and seed
        # the RNGs from the config — see their definitions.
        init_gpu_params(self.config)
        set_seed(self.config)
        self.logger = FileLogger('./log', self.config.is_master,
                                 self.config.is_master)
        self.load_data()
        self.logger.info(self.config)
        self.model = SearchCNNController(self.config, self.n_classes,
                                         self.output_mode)
        # Restore weights (or the BERT embedding weights) before the model is
        # moved to the device / wrapped.
        self.load_model()
        self.init_kd_component()
        if self.config.n_gpu > 0:
            self.model.to(device)
        if self.config.n_gpu > 1:
            # find_unused_parameters=True: presumably because not every
            # parameter receives a gradient on every step — TODO confirm.
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.config.local_rank],
                find_unused_parameters=True)
        # Unwrapped controller, used to call its own methods (print_alphas,
        # genotype, crit, weights, ...) even when self.model is a DDP wrapper.
        # NOTE(review): wrapping is keyed on n_gpu > 1 but unwrapping on
        # multi_gpu; the two flags must agree.
        self.model_to_print = self.model if self.config.multi_gpu is False else self.model.module
        self.architect = Architect(self.model, self.teacher_model, self.config,
                                   self.emd_tool)
        mb_params = param_size(self.model)
        self.logger.info("Model size = {:.3f} MB".format(mb_params))
        # Appended to by validate(); main() reports all entries and the best.
        self.eval_result_map = []
        self.init_optim()

    def init_kd_component(self):
        from transformers import BertForSequenceClassification
        self.teacher_model, self.emd_tool = None, None
        if self.config.use_kd:
            self.teacher_model = BertForSequenceClassification.from_pretrained(
                self.config.teacher_model, return_dict=False)
            self.teacher_model = self.teacher_model.to(device)
            self.teacher_model.eval()
            if self.config.use_emd:
                self.emd_tool = Emd_Evaluator(
                    self.config.layers,
                    12,
                    device,
                    weight_rate=self.config.weight_rate,
                    add_softmax=self.config.add_softmax)

    def load_data(self):
        """Seed the RNGs (when a seed is configured) and build the GLUE data loaders.

        Sets ``task_name``, ``train_loader``, ``arch_loader``, ``eval_loader``,
        ``output_mode``, ``n_classes`` and ``eval_sids``. It also replaces
        ``self.config`` with the config returned by ``load_glue_dataset``.
        The length of the training loader is logged.
        """
        seed = self.config.seed
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = True

        # Read the task name before the config object is replaced below.
        self.task_name = self.config.datasets
        loaded = load_glue_dataset(self.config)
        (self.train_loader, self.arch_loader, self.eval_loader, _unused,
         self.output_mode, self.n_classes, self.config,
         self.eval_sids) = loaded
        self.logger.info(f"train_loader length {len(self.train_loader)}")

    def init_optim(self):
        no_decay = ["bias", "LayerNorm.weight"]
        self.w_optim = torch.optim.SGD([
            p for n, p in self.model.named_parameters() if not any(
                nd in n
                for nd in no_decay) and p.requires_grad and 'alpha' not in n
        ],
                                       self.config.w_lr,
                                       momentum=self.config.w_momentum,
                                       weight_decay=self.config.w_weight_decay)
        if self.config.alpha_optim.lower() == 'adam':
            self.alpha_optim = torch.optim.Adam(
                [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                    and 'alpha' in n
                ],
                self.config.alpha_lr,
                weight_decay=self.config.alpha_weight_decay)
        elif self.config.alpha_optim.lower() == 'sgd':
            self.alpha_optim = torch.optim.SGD(
                [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                    and 'alpha' in n
                ],
                self.config.alpha_lr,
                weight_decay=self.config.alpha_weight_decay)
        else:
            raise NotImplementedError("no such optimizer")

        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.w_optim, self.config.epochs, eta_min=self.config.w_lr_min)

    def load_model(self):
        if self.config.restore != "":
            old_params_dict = dict()
            for k, v in self.model.named_parameters():
                old_params_dict[k] = v
            self.model.load_state_dict(torch.load(self.config.restore),
                                       strict=False)
            for k, v in self.model.named_parameters():
                if torch.sum(v) != torch.sum(old_params_dict[k]):
                    print(k + " not restore")
            del old_params_dict
        else:
            load_embedding_weight(
                self.model,
                'teacher_utils/bert_base_uncased/pytorch_model.bin', True,
                device)

    def save_checkpoint(self,
                        dump_path,
                        checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.model.module if hasattr(self.model,
                                                   "module") else self.model
        state_dict = mdl_to_save.state_dict()
        state_dict = {k: v for k, v in state_dict.items() if 'alpha' in k}
        torch.save(state_dict, os.path.join(dump_path, checkpoint_name))

    def main(self):
        # training loop
        best_top1 = 0.
        is_best = False
        for epoch in range(self.config.epochs):
            lr = self.lr_scheduler.get_last_lr()[-1]

            self.logger.info("Epoch {}".format(epoch))
            self.logger.info("Learning Rate {}".format(lr))
            start_time = time.time()

            # for k, v in self.model.named_parameters():
            #     print(k, torch.sum(v).item())

            if self.config.train_mode == 'sep':
                self.train_sep(lr, epoch)
            else:
                self.train(lr, epoch)
            self.logger.info(
                "Current epoch training time = {}".format(time.time() -
                                                          start_time))

            self.model_to_print.print_alphas(self.logger)
            self.lr_scheduler.step()

            self.logger.info('valid')
            cur_step = (epoch + 1) * len(self.train_loader)

            dev_start_time = time.time()
            top1 = self.validate(epoch, cur_step, "val")
            self.logger.info(
                "Current epoch vaildation time = {}".format(time.time() -
                                                            dev_start_time))
            self.logger.info(
                "Current epoch Total time = {}".format(time.time() -
                                                       start_time))

            # genotype
            genotypes = self.model_to_print.genotype()
            # if config.is_master:
            self.logger.info("========genotype========\n" + str(genotypes))
            # # save
            is_best = best_top1 <= top1
            if is_best:
                best_top1 = top1
                best_genotype = genotypes
            self.logger.info("Present best Prec@1 = {:.4%}".format(best_top1))
        if self.config.tb_dir != "":
            self.writer.close()
        # self.logger.info("Final best Prec@1 = " + str(best_top1))
        # self.logger.info("Best Genotype = " + str(best_genotype))
        self.logger.info("==========TRAINING_RESULT============")
        for x in self.eval_result_map:
            self.logger.info(x)
        evals = [x['eval_result'] for x in self.eval_result_map]
        evals[0] = -1
        best_ep = evals.index(max(evals))
        self.logger.info("==========BEST_RESULT============")
        self.logger.info(self.eval_result_map[best_ep])

    def train(self, lr, epoch):
        """Run one epoch of joint search: per batch, one alpha step then one weight step.

        For every training batch:

        - The next ``arch_loader`` batch drives
          ``Architect.unrolled_backward``, followed by an ``alpha_optim``
          step. The iterator is restarted when exhausted.
        - The training batch then drives a gradient-clipped ``w_optim`` step.
        - With ``config.use_kd``, the teacher's outputs for the training batch
          are computed without gradients and passed to the architect. The
          arch-batch teacher outputs are computed too, unless
          ``config.one_step`` is True.

        Args:
            lr: current weight learning rate, forwarded to the architect.
            epoch: zero-based epoch index, used for logging and step counting.
        """
        top1 = AverageMeter()
        losses = AverageMeter()

        self.model.train()
        total_num_step = len(self.train_loader)
        cur_step = epoch * total_num_step
        valid_iter = iter(self.arch_loader)

        # Steps after which a mid-epoch evaluation runs: train_eval_time - 1
        # evenly spaced points (the end-of-epoch evaluation is done by main()).
        point = [
            int(total_num_step * i / self.config.train_eval_time)
            for i in range(1, self.config.train_eval_time)
        ]

        for step, data in enumerate(self.train_loader):
            trn_X, trn_y = bert_batch_split(data, self.config.local_rank,
                                            device)
            try:
                v_data = next(valid_iter)
            except StopIteration:
                valid_iter = iter(self.arch_loader)
                v_data = next(valid_iter)
            val_X, val_y = bert_batch_split(v_data, self.config.local_rank,
                                            device)

            trn_t, val_t = None, None
            if self.config.use_kd:
                with torch.no_grad():
                    teacher_logits, teacher_reps = self.teacher_model(
                        input_ids=trn_X[0],
                        attention_mask=trn_X[1],
                        token_type_ids=trn_X[2])
                    trn_t = (teacher_logits, teacher_reps)

                    if self.config.one_step is not True:
                        v_teacher_logits, v_teacher_reps = self.teacher_model(
                            input_ids=val_X[0],
                            attention_mask=val_X[1],
                            token_type_ids=val_X[2])
                        val_t = (v_teacher_logits, v_teacher_reps)

            N = trn_X[0].size(0)

            # Architecture (alpha) step.
            self.alpha_optim.zero_grad()
            self.architect.unrolled_backward(trn_X, trn_y, val_X, val_y, trn_t,
                                             val_t, lr, self.w_optim)
            self.alpha_optim.step()
            if self.config.multi_gpu:
                torch.distributed.barrier()

            # Weight step on the training batch.
            self.w_optim.zero_grad()
            logits = self.model(trn_X)
            if self.config.use_emd:
                # The model also returns per-layer outputs; only logits are
                # needed for the weight loss.
                logits, _ = logits
            loss = self.model_to_print.crit(logits, trn_y)
            loss.backward()
            clip_grad_norm_(self.model_to_print.weights(),
                            self.config.w_grad_clip)
            self.w_optim.step()

            if self.config.tb_dir != "":
                self._write_train_scalars(loss, cur_step)

            preds = logits.detach().cpu().numpy()
            _, train_acc = get_acc_from_pred(self.output_mode,
                                             self.task_name, preds,
                                             trn_y.detach().cpu().numpy())

            losses.update(loss.item(), N)
            top1.update(train_acc, N)

            if self.config.eval_during_train and step + 1 in point:
                self.model_to_print.print_alphas(self.logger)
                self.logger.info(
                    "CURRENT Training Step [{:02d}/{:02d}] ".format(
                        step, total_num_step))
                self.validate(epoch, cur_step, mode="train_dev")
                genotypes = self.model_to_print.genotype()
                self.logger.info("========genotype========\n" +
                                 str(genotypes))
                # validate() puts the model in eval mode; switch back so the
                # rest of the epoch trains with training-mode layers (dropout).
                self.model.train()

            if step % self.config.print_freq == 0 or step == total_num_step - 1:
                self.logger.info(
                    "Train: , [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f}, Prec@(1,5) {top1.avg:.1%}"
                    .format(epoch + 1,
                            self.config.epochs,
                            step,
                            total_num_step - 1,
                            losses.avg,
                            top1=top1))
            cur_step += 1
        self.logger.info("{:.4%}".format(top1.avg))

    def _write_train_scalars(self, loss, cur_step):
        """Write per-layer raw/softmax alpha values and the training loss to TensorBoard."""
        # format_alphas lives on the unwrapped controller; going through
        # model_to_print keeps this working when self.model is DDP-wrapped.
        raw_alphas, softmax_alphas = self.model_to_print.format_alphas()
        for layer_index, dsi in enumerate(raw_alphas):
            self.writer.add_scalars(f'layer-{layer_index}-alpha',
                                    dsi,
                                    global_step=cur_step)
        for layer_index, dsi in enumerate(softmax_alphas):
            self.writer.add_scalars(f'layer-{layer_index}-softmax_alpha',
                                    dsi,
                                    global_step=cur_step)
        self.writer.add_scalar('loss', loss.item(), global_step=cur_step)

    def train_sep(self, lr, epoch):
        """Run one epoch of alternating search: alphas on odd epochs, weights on even ones.

        Odd epochs optimise the alphas on the weighted sum
        ``rep_loss * emd_rate + crit * alpha_ac_rate``. ``rep_loss`` is a
        representation-distillation loss when ``config.use_emd`` is set, and
        0.0 otherwise. Even epochs take gradient-clipped weight steps on the
        plain task loss.

        ``lr`` is accepted for interface parity with ``train`` and is not used
        here.
        """
        top1 = AverageMeter()
        losses = AverageMeter()
        self.model_to_print.train()
        total_num_step = len(self.train_loader)
        self.logger.info(total_num_step)
        train_component = False
        point = [
            int(total_num_step * i / self.config.train_eval_time)
            for i in range(1, self.config.train_eval_time)
        ]
        self.logger.info(f"TRAINING ALPHA: {train_component}")
        update_alpha = epoch % 2 == 1

        for step, data in enumerate(self.train_loader):
            trn_X, trn_y = bert_batch_split(data, self.config.local_rank,
                                            device)
            batch_size = trn_X[0].size(0)

            if self.config.multi_gpu:
                torch.distributed.barrier()
            self.alpha_optim.zero_grad()
            self.w_optim.zero_grad()
            # NOTE(review): this calls the unwrapped module, bypassing the DDP
            # wrapper, so gradients are not all-reduced across ranks in
            # multi-GPU runs — confirm this is intended.
            outputs = self.model_to_print(trn_X)
            if self.config.use_emd:
                logits, s_layer_out = outputs
            else:
                logits = outputs

            if update_alpha:
                rep_loss = (self._sep_representation_loss(s_layer_out, trn_X)
                            if self.config.use_emd else 0.0)
                loss = rep_loss * self.config.emd_rate + self.model_to_print.crit(
                    logits, trn_y) * self.config.alpha_ac_rate
                loss.backward()
                self.alpha_optim.step()
            else:
                loss = self.model_to_print.crit(logits, trn_y)
                loss.backward()
                clip_grad_norm_(self.model_to_print.weights(),
                                self.config.w_grad_clip)
                self.w_optim.step()

            _, train_acc = get_acc_from_pred(self.output_mode,
                                             self.task_name,
                                             logits.detach().cpu().numpy(),
                                             trn_y.detach().cpu().numpy())
            losses.update(loss.item(), batch_size)
            top1.update(train_acc, batch_size)

            is_logging_rank = self.config.local_rank == 0
            if self.config.eval_during_train and is_logging_rank and step + 1 in point:
                self.model_to_print.print_alphas(self.logger)
                self.logger.info(
                    "CURRENT Training Step [{:02d}/{:02d}] ".format(
                        step, total_num_step))
                self.validate(epoch, step, mode="train_dev")
                self.logger.info("========genotype========\n" +
                                 str(self.model_to_print.genotype()))
                self.model.train()
            if self.config.multi_gpu:
                torch.distributed.barrier()

            if step % 50 == 0 and is_logging_rank:
                self.logger.info(
                    "Train: , [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f}, Prec@(1,5) {top1.avg:.1%}"
                    .format(epoch + 1,
                            self.config.epochs,
                            step,
                            total_num_step - 1,
                            losses.avg,
                            top1=top1))
                if epoch % self.config.alpha_ep != 0 and self.config.update_emd and self.config.use_emd:
                    self.logger.info("s weight:{}".format(
                        self.emd_tool.s_weight))
                    self.logger.info("t weight:{}".format(
                        self.emd_tool.t_weight))
        self.logger.info("{:.4%}".format(top1.avg))

    def _sep_representation_loss(self, s_layer_out, trn_X):
        """Distillation loss between the student's layer outputs and the teacher's.

        With ``config.skip_mapping`` this is a sum of per-layer MSE terms.
        Otherwise it is the EMD tool's loss, and the EMD layer weights are
        updated when ``config.update_emd`` is set.
        """
        with torch.no_grad():
            _, teacher_reps = self.teacher_model(
                input_ids=trn_X[0],
                attention_mask=trn_X[1],
                token_type_ids=trn_X[2])
        if self.config.hidn2attn:
            s_layer_out = convert_to_attn(s_layer_out, trn_X[1])
            teacher_reps = convert_to_attn(teacher_reps, trn_X[1])
        if self.config.skip_mapping:
            # Drop the first teacher output (presumably the embeddings), then
            # pair student layers with every third teacher layer.
            rep_loss = 0
            for s_rep, t_rep in zip(s_layer_out, teacher_reps[1:][2::3]):
                rep_loss += nn.MSELoss()(s_rep, t_rep)
            return rep_loss
        rep_loss, flow, distance = self.emd_tool.loss(
            s_layer_out, teacher_reps, return_distance=True)
        if self.config.update_emd:
            self.emd_tool.update_weight(flow, distance)
        return rep_loss

    def validate(self, epoch, cur_step, mode="dev"):
        """Evaluate the student model on ``self.eval_loader`` and record the result.

        Runs the model over every batch of the eval loader under
        ``torch.no_grad()``. It sums the per-batch classification loss
        (``model_to_print.crit``). When ``config.use_emd`` is set, it also sums
        the EMD representation loss between the student's layer outputs and the
        teacher model's representations. Predictions are scored with
        ``get_acc_from_pred``. The result is logged and appended to
        ``self.eval_result_map`` together with the current alphas and genotype.
        The student model is left in eval mode.

        Args:
            epoch: Zero-based epoch index (logged as ``epoch + 1``).
            cur_step: Training step at which this evaluation runs; stored in
                ``eval_result_map``.
            mode: Label used as the log prefix and stored in
                ``eval_result_map`` (e.g. ``"dev"`` or ``"train_dev"``).

        Returns:
            The ``acc`` value produced by ``get_acc_from_pred``.

        Raises:
            ValueError: If ``self.eval_loader`` yields no batches.
        """
        eval_labels = []
        # Keep per-batch logits in a list and concatenate once at the end.
        # Growing one array with np.append copies everything accumulated so
        # far on every batch, which is quadratic in the number of examples.
        batch_logits = []
        self.model_to_print.eval()

        total_loss, total_emd_loss = 0, 0
        task_name = self.task_name
        with torch.no_grad():
            for data in self.eval_loader:
                X, y = bert_batch_split(data, self.config.local_rank, device)
                logits = self.model(X, train=False)
                if self.config.use_emd:
                    # With EMD enabled the model output is (logits, layer_outputs).
                    logits, s_layer_out = logits
                    _, teacher_reps = self.teacher_model(
                        input_ids=X[0],
                        attention_mask=X[1],
                        token_type_ids=X[2])
                    if self.config.hidn2attn:
                        s_layer_out = convert_to_attn(s_layer_out, X[1])
                        teacher_reps = convert_to_attn(teacher_reps, X[1])
                    # Only the loss is used here; the flow/distance outputs
                    # are discarded.
                    rep_loss, _, _ = self.emd_tool.loss(
                        s_layer_out, teacher_reps, return_distance=True)
                    total_emd_loss += rep_loss.item()
                loss = self.model_to_print.crit(logits, y)
                total_loss += loss.item()
                batch_logits.append(logits.detach().cpu().numpy())
                eval_labels.extend(y.detach().cpu().numpy())

            if not batch_logits:
                raise ValueError("validate: eval_loader yielded no batches")
            preds = np.concatenate(batch_logits, axis=0)

            if self.task_name.lower() == 'wikiqa':
                # WikiQA is scored from per-row softmax scores keyed by the dev
                # uids. The labels argument becomes the dev TSV path instead of
                # a label list.
                preds = {
                    "uids": self.eval_sids['dev'],
                    'scores': np.reshape(np.array([softmax(x) for x in preds]),
                                         -1)
                }
                eval_labels = "data/WikiQA/WikiQA-dev.tsv"
                task_name += 'dev'

            result, acc = get_acc_from_pred(self.output_mode, task_name, preds,
                                            eval_labels)

        # Note: the logged loss values are sums over batches, not averages.
        self.logger.info(
            mode + ": [{:2d}/{}] Final Prec@1 {} Loss {}, EMD loss: {}".format(
                epoch + 1, self.config.epochs, result, total_loss,
                total_emd_loss))
        alpha_soft, alpha_ori = self.model_to_print.get_current_alphas()
        self.eval_result_map.append({
            'mode': mode,
            "epoch": epoch,
            "step": cur_step,
            "eval_result": acc,
            "alpha": alpha_soft,
            "alpha_ori": alpha_ori,
            "genotypes": self.model_to_print.genotype()
        })
        return acc