Example #1
def train(train_queue, model, criterion, optimizer, train_logger):
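    """Train the network for one epoch on `train_queue`.

    Standard supervised loop: forward pass (the model returns main and
    auxiliary logits), loss with an optional auxiliary term weighted by
    args.auxiliary_weight, backward pass, gradient clipping, optimizer step.
    Batch losses are accumulated in an AverageMeter and logged through
    utils.log_loss. Assumes the usual module-level imports (torch,
    torch.nn as nn, torch.autograd.Variable, utils, logging) and a parsed
    `args` namespace, as in the other examples.
    """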
    objs = utils.AverageMeter()
    model.train()

    batches = len(train_queue)
    for step, (input, target) in enumerate(train_queue):
        input = Variable(input.float()).cuda()
        target = Variable(target.float()).cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(torch.squeeze(logits), target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        n = input.size(0)
        objs.update(loss.item(), n)
        utils.log_loss(train_logger, loss.item(), None, 1 / batches)

        if step % args.report_freq == 0:
            logging.info('train %03d %e', step, objs.avg)

    return objs.avg
Example #2
def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, loggers):
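    """One epoch of bilevel (DARTS-style) architecture search.

    For each training batch, a validation batch is drawn and the architect
    updates the architecture parameters (optionally with the unrolled
    second-order approximation); the network weights are then updated on the
    training batch with gradient clipping. model.tick advances the model
    clock by 1/len(train_queue) per batch. valid_queue is assumed to yield
    at least as many batches as train_queue, since valid_iter is never reset.
    """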
    objs = utils.AverageMeter()

    valid_iter = iter(valid_queue)
    batches = len(train_queue)
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        model.tick(1 / batches)

        input = Variable(input.float(), requires_grad=False).cuda(non_blocking=True)
        target = Variable(target.float(), requires_grad=False).cuda(non_blocking=True)

        input_search, target_search = next(valid_iter)
        input_search = Variable(input_search.float(), requires_grad=False).cuda(non_blocking=True)
        target_search = Variable(target_search.float(), requires_grad=False).cuda(non_blocking=True)

        valid_loss = architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
        utils.log_loss(loggers["val"], valid_loss.item(), None, model.clock)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        objs.update(loss.item(), n)
        utils.log_loss(loggers["train"], loss.item(), None, model.clock)

        if step % args.report_freq == 0:
            logging.info('train %03d %e', step, objs.avg)

    return objs.avg
Example #3
def train(train_queue, valid_iter, model, architect, criterion, optimizer, lr,
          loggers):
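    """Search epoch that additionally tracks accuracy, FI statistics and
    ADMM variables.

    Same bilevel structure as the previous example, but after each weight
    update the model masks its alphas, records FI statistics (track_FI) and
    updates its history; every args.admm_freq batches the ADMM variables Z
    and U are updated. `valid_iter` is expected to be an already-created
    iterator over the validation queue.
    """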
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()

    batches = len(train_queue)
    for step, (input, target) in enumerate(train_queue):
        model.train()
        model.tick(1 / batches)
        n = input.size(0)

        input = Variable(input, requires_grad=False).cuda(non_blocking=True)
        target = Variable(target, requires_grad=False).cuda(non_blocking=True)

        # get a random minibatch from the search queue without replacement
        input_search, target_search = next(valid_iter)
        input_search = Variable(input_search,
                                requires_grad=False).cuda(non_blocking=True)
        target_search = Variable(target_search,
                                 requires_grad=False).cuda(non_blocking=True)

        valid_loss = architect.step(input,
                                    target,
                                    input_search,
                                    target_search,
                                    lr,
                                    optimizer,
                                    unrolled=args.unrolled)
        utils.log_loss(loggers["val"], valid_loss, None, model.clock)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        model.mask_alphas()
        model.track_FI()
        model.update_history()

        prec1 = utils.accuracy(logits, target, topk=(1, ))
        objs.update(loss.item(), n)
        top1.update(prec1[0].item(), n)
        utils.log_loss(loggers["train"], loss.item(), prec1[0].item(),
                       model.clock)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, objs.avg, top1.avg)

        if (step + 1) % args.admm_freq == 0:
            model.update_Z()
            model.update_U()

    return top1.avg, objs.avg
Example #4
def train(train_queue, valid_queue, model, architect, criterion, optimizer,
          loggers):
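    """Search epoch with a first-order architect step on the validation batch.

    The architect is stepped on (input_search, target_search) only; after
    backpropagating the training loss, the squared global gradient norm is
    appended to model.FI_hist and the fractional batch position to
    model.batchstep before the optimizer step.
    """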
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()

    valid_iter = iter(valid_queue)
    batches = len(train_queue)
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        model.tick(1 / batches)

        input = Variable(input, requires_grad=False).cuda(non_blocking=True)
        target = Variable(target, requires_grad=False).cuda(non_blocking=True)

        # get a random minibatch from the search queue with replacement
        input_search, target_search = next(valid_iter)
        input_search = Variable(input_search,
                                requires_grad=False).cuda(non_blocking=True)
        target_search = Variable(target_search,
                                 requires_grad=False).cuda(non_blocking=True)

        architect.step(input_search, target_search)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        model.FI_hist.append(
            torch.norm(
                torch.stack([
                    torch.norm(p.grad.detach(), 2.0).cuda()
                    for p in model.parameters() if p.grad is not None
                ]), 2.0)**2)
        if len(model.batchstep) > 0:
            model.batchstep.append(model.batchstep[-1] + 1 / batches)
        else:
            model.batchstep.append(0.0)
        optimizer.step()

        prec1 = utils.accuracy(logits, target, topk=(1, ))
        objs.update(loss.item(), n)
        top1.update(prec1[0].item(), n)
        utils.log_loss(loggers["train"], loss.item(), prec1[0].item(),
                       model.clock)
        model.update_history()

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, objs.avg, top1.avg)

    return top1.avg, objs.avg
Example #5
def main():
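    """Entry point for architecture search on CIFAR-10 with ADMM-style
    regularization.

    Sets the seeds and CUDA device, builds the search Network (with args.rho
    and args.ewma), splits the CIFAR-10 training set into train/validation
    queues via SubsetRandomSampler, and runs the threshold-scheduled search
    epoch (see the matching `train` below) once per epoch. After each epoch
    it logs the genotype, runs `infer` on the validation queue, plots losses
    and FI curves, saves the alpha/FI histories, and finally writes the
    genotype to genotype.txt.
    """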
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion, args.rho, args.ewma)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    datapath = os.path.join(utils.get_dir(), args.data)
    train_data = dset.CIFAR10(root=datapath, train=True, download=True, transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    model.initialize_Z_and_U()

    loggers = {"train": {"loss": [], "acc": [], "step": []},
               "val": {"loss": [], "acc": [], "step": []},
               "infer": {"loss": [], "acc": [], "step": []},
               "ath": {"threshold": [], "step": []},
               "zuth": {"threshold": [], "step": []},
               "astep": [],
               "zustep": []}

    if args.constant_alpha_threshold < 0:
        alpha_threshold = args.init_alpha_threshold
    else:
        alpha_threshold = args.constant_alpha_threshold
    zu_threshold = args.init_zu_threshold
    alpha_counter = 0
    ewma = -1

    for epoch in range(args.epochs):
        valid_iter = iter(valid_queue)
        model.clear_U()

        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(torch.clamp(model.alphas_normal, min=0.1, max=1.0))
        print(torch.clamp(model.alphas_reduce, min=0.1, max=1.0))

        # training
        train_acc, train_obj, alpha_threshold, zu_threshold, alpha_counter, ewma = train(
            train_queue, valid_iter, model, architect, criterion, optimizer, lr,
            loggers, alpha_threshold, zu_threshold, alpha_counter, ewma, args)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        utils.log_loss(loggers["infer"], valid_obj, valid_acc, model.clock)
        logging.info('valid_acc %f', valid_acc)

        utils.plot_loss_acc(loggers, args.save)

        # model.update_history()

        utils.save_file(recoder=model.alphas_normal_history, path=os.path.join(args.save, 'normalalpha'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=model.alphas_reduce_history, path=os.path.join(args.save, 'reducealpha'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=model.FI_normal_history, path=os.path.join(args.save, 'normalFI'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=model.FI_reduce_history, path=os.path.join(args.save, 'reduceFI'),
                        steps=loggers["train"]["step"])

        scaled_FI_normal = scale(model.FI_normal_history, model.alphas_normal_history)
        scaled_FI_reduce = scale(model.FI_reduce_history, model.alphas_reduce_history)
        utils.save_file(recoder=scaled_FI_normal, path=os.path.join(args.save, 'normalFIscaled'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=scaled_FI_reduce, path=os.path.join(args.save, 'reduceFIscaled'),
                        steps=loggers["train"]["step"])

        utils.plot_FI(loggers["train"]["step"], model.FI_history, args.save, "FI", loggers["ath"], loggers['astep'])
        utils.plot_FI(loggers["train"]["step"], model.FI_ewma_history, args.save, "FI_ewma", loggers["ath"],
                      loggers['astep'])
        utils.plot_FI(model.FI_alpha_history_step, model.FI_alpha_history, args.save, "FI_alpha", loggers["zuth"],
                      loggers['zustep'])

        utils.save(model, os.path.join(args.save, 'weights.pt'))

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    f = open(os.path.join(args.save, 'genotype.txt'), "w")
    f.write(str(genotype))
    f.close()
Example #6
def train(train_queue, valid_iter, model, architect, criterion, optimizer, lr, loggers, alpha_threshold, zu_threshold,
          alpha_counter, ewma, args):
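    """Search epoch with dynamically scheduled alpha and ADMM (Z/U) updates.

    An architecture (alpha) step is taken only when model.FI_ewma falls
    below alpha_threshold; unless a constant threshold is configured, the
    threshold is halved after an alpha step and multiplied by 1.1 otherwise.
    Z/U updates are either gated by model.FI_alpha against zu_threshold
    (args.scheduled_zu) or taken every args.admm_freq alpha steps. Returns
    the updated thresholds and counters so they persist across epochs.
    """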
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()

    batches = len(train_queue)
    valid_loss = 0.0  # defined up front so the post-loop log call is valid even if no alpha step fires
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        model.tick(1 / batches)
        alpha_step = False

        print("FI: ", model.FI, "FI_ewma: ", model.FI_ewma, " alpha_threshold: ", alpha_threshold)
        loggers["ath"]["threshold"].append(alpha_threshold)
        loggers["ath"]["step"].append(model.clock)
        if (model.FI_ewma > 0.0) & (model.FI_ewma < alpha_threshold):
            print("alpha step")
            # get a random minibatch from the search queue without replacement
            input_search, target_search = next(valid_iter)
            input_search = Variable(input_search, requires_grad=False).cuda(non_blocking=True)
            target_search = Variable(target_search, requires_grad=False).cuda(non_blocking=True)

            valid_loss = architect.step(input, target, input_search, target_search, lr, optimizer,
                                        unrolled=args.unrolled)
            utils.log_loss(loggers["val"], valid_loss, None, model.clock)
            # alpha_threshold = args.init_alpha_threshold
            if args.constant_alpha_threshold < 0:
                alpha_threshold *= 0.5
            alpha_step = True
            alpha_counter += 1
            loggers["astep"].append(model.clock)
        elif args.constant_alpha_threshold < 0:
            alpha_threshold *= 1.1

        input = Variable(input, requires_grad=False).cuda(non_blocking=True)
        target = Variable(target, requires_grad=False).cuda(non_blocking=True)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        model.track_FI(alpha_step)
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        model.mask_alphas()

        model.update_history()

        prec1 = utils.accuracy(logits, target, topk=(1,))
        objs.update(loss.item(), n)
        top1.update(prec1[0].item(), n)
        utils.log_loss(loggers["train"], loss.item(), prec1[0].item(), model.clock)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, objs.avg, top1.avg)

        if args.scheduled_zu:
            print("FI_alpha: ", model.FI_alpha, " zu_threshold: ", zu_threshold)
            loggers["zuth"]["threshold"].append(zu_threshold)
            loggers["zuth"]["step"].append(model.clock)
            if alpha_step & (model.FI_alpha > 0.0) & (model.FI_alpha < zu_threshold):
                print("zu step")
                model.update_Z()
                model.update_U()
                # zu_threshold = args.init_zu_threshold
                zu_threshold *= 0.5
                loggers["zustep"].append(model.clock)
                alpha_counter = 0
                # reset alpha threshold?
            elif alpha_step:
                zu_threshold *= 1.1
        else:
            if (alpha_counter + 1) % args.admm_freq == 0:
                model.update_Z()
                model.update_U()
                loggers["zustep"].append(model.clock)
                alpha_counter = 0

    utils.log_loss(loggers["val"], valid_loss, None, model.clock)
    return top1.avg, objs.avg, alpha_threshold, zu_threshold, alpha_counter, ewma
Example #7
def main():
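    """Entry point for architecture search on a bathymetry regression task.

    Uses MSELoss and a single-output search Network with 4 input channels,
    loads the mixed_train / mixed_validation CSVs through
    utils.BathymetryDataset, and runs one DARTS-style search epoch per
    epoch. The softmaxed alphas are printed every epoch and saved as .npy
    files at the end, together with the final genotype.
    """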
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.MSELoss()
    criterion = criterion.cuda()
    model = Network(args.init_channels, 1, args.layers, criterion, input_channels=4)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # dataset = utils.BathymetryDataset(args, "guyane/guyane.csv")
    # dataset.add(args, "saint_louis/saint_louis.csv")

    dataset = utils.BathymetryDataset(args, "../mixed_train.csv", to_filter=False)
    dataset.add(args, "../mixed_validation.csv", to_balance=False)

    trains, vals = dataset.get_subset_indices(args.train_portion)

    train_queue = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(trains),
        pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(vals),
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    loggers = {"train": {"loss": [], "step": []}, "val": {"loss": [], "step": []}, "infer": {"loss": [], "step": []}}

    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        _ = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, loggers)

        # validation
        infer_loss = infer(valid_queue, model, criterion)
        utils.log_loss(loggers["infer"], infer_loss, None, model.clock)

        utils.plot_loss_acc(loggers, args.save)

        model.update_history()

        utils.save_file(recoder=model.alphas_normal_history, path=os.path.join(args.save, 'normal'))
        utils.save_file(recoder=model.alphas_reduce_history, path=os.path.join(args.save, 'reduce'))

        utils.save(model, os.path.join(args.save, 'weights.pt'))

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))

    np.save(os.path.join(args.save, 'normal_weight.npy'),
            F.softmax(model.alphas_normal, dim=-1).data.cpu().numpy())
    np.save(os.path.join(args.save, 'reduce_weight.npy'),
            F.softmax(model.alphas_reduce, dim=-1).data.cpu().numpy())

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    f = open(os.path.join(args.save, 'genotype.txt'), "w")
    f.write(str(genotype))
    f.close()
Example #8
def main():
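    """Train a discovered architecture from scratch on the bathymetry data.

    Loads the genotype from genotype.txt under args.genotype_path if it
    exists, otherwise from the named entry in `genotypes`, builds the
    evaluation Network (auxiliary head optional, 4 input channels), and
    trains it with SGD and cosine annealing using a plain supervised train
    loop (cf. Example #1). Weights are checkpointed every 50 epochs.
    """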
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype_path = os.path.join(utils.get_dir(), args.genotype_path,
                                 'genotype.txt')
    if os.path.isfile(genotype_path):
        with open(genotype_path, "r") as f:
            geno_raw = f.read()
            genotype = eval(geno_raw)
    else:
        genotype = eval("genotypes.%s" % args.arch)

    f = open(os.path.join(args.save, 'genotype.txt'), "w")
    f.write(str(genotype))
    f.close()

    model = Network(args.init_channels,
                    1,
                    args.layers,
                    args.auxiliary,
                    genotype,
                    input_channels=4)
    model = model.cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.MSELoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # train_transform, valid_transform = utils._data_transforms_cifar10(args)
    # datapath = os.path.join(utils.get_dir(), args.data)
    # train_data = dset.CIFAR10(root=datapath, train=True, download=True, transform=train_transform)
    # valid_data = dset.CIFAR10(root=datapath, train=False, download=True, transform=valid_transform)
    train_data = utils.BathymetryDataset(args,
                                         "../mixed_train.csv",
                                         to_filter=False)
    valid_data = utils.BathymetryDataset(args,
                                         "../mixed_validation.csv",
                                         to_filter=False)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=2)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    loggers = {
        "train": {
            "loss": [],
            "step": []
        },
        "val": {
            "loss": [],
            "step": []
        }
    }

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        _ = train(train_queue, model, criterion, optimizer, loggers["train"])

        infer_loss = infer(valid_queue, model, criterion)
        utils.log_loss(loggers["val"], infer_loss, None, 1)

        utils.plot_loss_acc(loggers, args.save)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        if (epoch + 1) % 50 == 0:
            utils.save(
                model,
                os.path.join(args.save,
                             'checkpoint' + str(epoch) + 'weights.pt'))


def train(train_queue, valid_queue, model, architect, criterion, optimizer,
          loggers, alpha_threshold, alpha_counter, ewma, args):
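    """Search epoch where alpha steps are taken either at a fixed frequency
    (every int(args.schedfreq) batches) or dynamically from the FI EWMA
    (args.dyno_schedule), scaling alpha_threshold by args.threshold_divider
    after a step and by args.threshold_multiplier otherwise.

    The validation iterator is rebuilt on StopIteration, CPU execution is
    supported when args.gpu == -1, and ADMM Z/U updates fire every
    args.admm_freq alpha steps when args.reg == "admm".
    """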
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    valid_iter = iter(valid_queue)
    # print("valid len:", len(valid_queue))

    batches = len(train_queue)
    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        model.tick(1 / batches)

        valid_loss = 0.0

        loggers["ath"]["threshold"].append(alpha_threshold)
        loggers["ath"]["step"].append(model.clock)
        if (not args.dyno_schedule and (step + 1) % int(args.schedfreq)
                == 0) or (args.dyno_schedule and model.FI_ewma > 0.0
                          and model.FI_ewma < alpha_threshold):
            # print("alpha step")
            try:
                input_search, target_search = next(valid_iter)
            except StopIteration:
                print("reset valid iter")
                valid_iter = iter(valid_queue)
                input_search, target_search = next(valid_iter)
            # build the tensors on CPU first; move to GPU only when one is configured
            input_search = Variable(input_search, requires_grad=False)
            target_search = Variable(target_search, requires_grad=False)
            if args.gpu != -1:
                input_search = input_search.cuda(non_blocking=True)
                target_search = target_search.cuda(non_blocking=True)

            valid_loss = architect.step(input_search, target_search)
            utils.log_loss(loggers["val"], valid_loss, None, model.clock)
            if args.dyno_schedule:
                alpha_threshold *= args.threshold_divider
            alpha_counter += 1
            loggers["astep"].append(model.clock)
        elif args.dyno_schedule:
            alpha_threshold *= args.threshold_multiplier

        # keep the batch on CPU unless a GPU is configured (args.gpu == -1 means CPU)
        input = Variable(input, requires_grad=False)
        target = Variable(target, requires_grad=False)
        if args.gpu != -1:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        model.track_FI()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        model.mask_alphas()

        model.update_history()

        prec1 = utils.accuracy(logits, target, topk=(1, ))
        objs.update(loss.detach().item(), n)
        top1.update(prec1[0].item(), n)
        utils.log_loss(loggers["train"], loss, prec1[0].item(), model.clock)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, objs.avg, top1.avg)

        if (args.reg == "admm") & ((alpha_counter + 1) % args.admm_freq == 0):
            model.update_Z()
            model.update_U()
            loggers["zustep"].append(model.clock)
            alpha_counter = 0

    utils.log_loss(loggers["val"], valid_loss, None, model.clock)
    return top1.avg, objs.avg, alpha_threshold, alpha_counter, ewma


def main():
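    """Entry point for regularized architecture search on CIFAR-10/100.

    Runs on CPU when args.gpu == -1, otherwise on the selected GPU. When
    args.dyno_schedule / args.dyno_split are set, the threshold divider and
    train portion are derived from args.schedfreq. Builds train/validation
    queues for CIFAR-10, CIFAR-100, or the CIFAR100C2F-based tasks
    (CIFAR100cf, CIFAR100split), then runs the scheduled `train` loop above
    each epoch, evaluates with `infer`, optionally checkpoints the genotype
    every args.ckpt_interval epochs, and saves alpha histories, FI plots and
    weights.
    """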
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu != -1:
        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed(args.seed)
        logging.info('gpu device = %d' % args.gpu)
    else:
        logging.info('using cpu')

    if args.dyno_schedule:
        args.threshold_divider = np.exp(-np.log(args.threshold_multiplier) *
                                        args.schedfreq)
        print(
            args.threshold_divider, -np.log(args.threshold_multiplier) /
            np.log(args.threshold_divider))
    if args.dyno_split:
        args.train_portion = 1 - 1 / (1 + args.schedfreq)

    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    if args.gpu != -1:
        criterion = criterion.cuda()
    model = Network(args.init_channels,
                    CIFAR_CLASSES,
                    args.layers,
                    criterion,
                    args.rho,
                    args.crb,
                    args.epochs,
                    args.gpu,
                    ewma=args.ewma,
                    reg=args.reg)
    if args.gpu != -1:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    datapath = os.path.join(utils.get_dir(), args.data)
    if args.task == "CIFAR100cf":
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = utils.CIFAR100C2F(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        num_train = len(train_data)
        indices = list(range(num_train))

        split = int(np.floor(args.train_portion * len(indices)))

        orig_num_train = len(indices[:split])
        orig_num_valid = len(indices[split:num_train])

        train_indices = train_data.filter_by_fine(args.train_filter,
                                                  indices[:split])
        valid_indices = train_data.filter_by_fine(args.valid_filter,
                                                  indices[split:num_train])

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=utils.FillingSubsetRandomSampler(train_indices,
                                                     orig_num_train,
                                                     reshuffle=True),
            pin_memory=True,
            num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=utils.FillingSubsetRandomSampler(valid_indices,
                                                     orig_num_valid,
                                                     reshuffle=True),
            pin_memory=True,
            num_workers=2)
        # TODO: extend each epoch or multiply number of epochs by 20%*args.class_filter
    elif args.task == "CIFAR100split":
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = utils.CIFAR100C2F(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        if not args.evensplit:
            train_indices, valid_indices = train_data.split(args.train_portion)
        else:
            num_train = len(train_data)
            indices = list(range(num_train))

            split = int(np.floor(args.train_portion * num_train))

            train_indices = indices[:split]
            valid_indices = indices[split:num_train]

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_indices),
            pin_memory=True,
            num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                valid_indices),
            pin_memory=True,
            num_workers=2)
    else:
        if args.task == "CIFAR100":
            train_transform, valid_transform = utils._data_transforms_cifar100(
                args)
            train_data = dset.CIFAR100(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        else:
            train_transform, valid_transform = utils._data_transforms_cifar10(
                args)
            train_data = dset.CIFAR10(root=datapath,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
        num_train = len(train_data)
        indices = list(range(num_train))

        split = int(np.floor(args.train_portion * num_train))

        train_indices = indices[:split]
        valid_indices = indices[split:num_train]

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_indices),
            pin_memory=True,
            num_workers=4)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                valid_indices),
            pin_memory=True,
            num_workers=4)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    loggers = {
        "train": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "val": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "infer": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "ath": {
            "threshold": [],
            "step": []
        },
        "astep": [],
        "zustep": []
    }

    alpha_threshold = args.init_alpha_threshold
    alpha_counter = 0
    ewma = -1

    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)
        if args.ckpt_interval > 0 and epoch > 0 and epoch % args.ckpt_interval == 0:
            logging.info('checkpointing genotype')
            os.mkdir(os.path.join(args.save, 'genotypes', str(epoch)))
            with open(
                    os.path.join(args.save, 'genotypes', str(epoch),
                                 'genotype.txt'), "w") as f:
                f.write(str(genotype))

        print(model.activate(model.alphas_normal))
        print(model.activate(model.alphas_reduce))

        # training
        train_acc, train_obj, alpha_threshold, alpha_counter, ewma = train(
            train_queue, valid_queue, model, architect, criterion, optimizer,
            loggers, alpha_threshold, alpha_counter, ewma, args)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        utils.log_loss(loggers["infer"], valid_obj, valid_acc, model.clock)
        logging.info('valid_acc %f', valid_acc)

        utils.plot_loss_acc(loggers, args.save)

        utils.save_file(recoder=model.alphas_normal_history,
                        path=os.path.join(args.save, 'Normalalpha'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=model.alphas_reduce_history,
                        path=os.path.join(args.save, 'Reducealpha'),
                        steps=loggers["train"]["step"])

        utils.plot_FI(loggers["train"]["step"], model.FI_history, args.save,
                      "FI", loggers["ath"], loggers['astep'])
        utils.plot_FI(loggers["train"]["step"], model.FI_ewma_history,
                      args.save, "FI_ewma", loggers["ath"], loggers['astep'])

        utils.save(model, os.path.join(args.save, 'weights.pt'))

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    f = open(os.path.join(args.save, 'genotype.txt'), "w")
    f.write(str(genotype))
    f.close()