val_loss,
            val_r_sq,
            val_accu
        )

        # print loss
        if epoch % PRINT_EVERY == 0:
            print('='*20 + '\nEpoch %d / %d\n' % (epoch, NUM_EPOCHS) + '='*20)
            print('[%s (%d %.1f%%)]' % (train_utils.time_since(START), epoch, float(epoch) / NUM_EPOCHS * 100))
            print('[%s %0.5f, %s %0.5f, %s %0.5f]'% ('Train Loss: ', train_loss, ' R-sq: ', train_r_sq, ' Accu:', train_accu))
            print('[%s %0.5f, %s %0.5f, %s %0.5f]'% ('Valid Loss: ', val_loss, ' R-sq: ', val_r_sq, ' Accu:', val_accu))

        # save model if best validation loss
        if val_loss < best_val_loss:
            n = file_info + '_best'
            train_utils.save(n, perf_model)
            best_val_loss = val_loss

    print("Saving...")
    train_utils.save(file_info, perf_model, log_parameters)

except KeyboardInterrupt:
    print("Saving before quit...")
    train_utils.save(file_info, perf_model, log_parameters)

# RUN VALIDATION SET ON THE BEST MODEL
# read the best model
filename = file_info + '_best' + '_Reg'
if torch.cuda.is_available():
    perf_model.cuda()
    perf_model.load_state_dict(torch.load('saved/' + filename + '.pt'))
        log_value('val_r_sq', val_r_sq, epoch)
        log_value('train_accu', train_accu, epoch)
        log_value('val_accu', val_accu, epoch)
        log_value('train_accu2', train_accu2, epoch)
        log_value('val_accu2', val_accu2, epoch)
        #####

        # print loss
        if epoch % PRINT_EVERY == 0:
            print('[%s (%d %.1f%%)]' % (train_utils.time_since(START), epoch, float(epoch) / NUM_EPOCHS * 100))
            print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]'% ('Train Loss: ', train_loss, ' R-sq: ', train_r_sq, ' Accu:', train_accu, train_accu2))
            print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]'% ('Valid Loss: ', val_loss, ' R-sq: ', val_r_sq, ' Accu:', val_accu, val_accu2))
        # save model if best validation loss
        if val_loss.item() < best_val_loss:
            n = NAME + '_best'
            train_utils.save(n, perf_model)
            best_val_loss = val_loss.item()
            best_epoch = epoch
        # store the best r-squared value from training
        if val_r_sq > best_valrsq:
            best_valrsq = val_r_sq
        if best_epoch < epoch - 200:
            break
    print("Saving...")
    train_utils.save(NAME, perf_model)
except KeyboardInterrupt:
    print("Saving before quit...")
    train_utils.save(NAME, perf_model)

print('BEST R^2 VALUE: ' + str(best_valrsq))
예제 #3
0
def main(args):
    """Train a fixed (already-searched) CNN architecture on CIFAR-10.

    Best-effort: downloads the latest genotypes file from S3 and appends any
    new ``Genotype`` lines to the local genotypes module, then builds the
    network named by ``args.train.arch`` and runs the train/validate loop,
    checkpointing after every epoch.

    Args:
        args: config object with ``run`` (s3_bucket, seed, gpu, epochs,
            report_freq) and ``train`` (arch, init_channels, layers,
            auxiliary, drop_path_prob, ...) sections.
    """
    s3_bucket = args.run.s3_bucket
    log = os.path.join(os.getcwd(), "log.txt")

    # Resume the previous log from S3 so logging appends across restarts.
    if s3_bucket is not None:
        aws_utils.download_from_s3(log, s3_bucket, log)

    train_utils.set_up_logging(log)

    CIFAR_CLASSES = 10

    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    # Best effort: pull newly discovered genotypes from S3 and append them to
    # the genotypes module so the architecture named below can be resolved.
    try:
        aws_utils.download_from_s3("cnn_genotypes.txt", s3_bucket,
                                   "/tmp/cnn_genotypes.txt")

        with open("/code/nas-theory/cnn/search_spaces/darts/genotypes.py",
                  "a") as f:
            with open("/tmp/cnn_genotypes.txt") as archs:
                f.write("\n")
                for line in archs:
                    if "Genotype" in line:
                        f.write(line)
        print("Downloaded genotypes from aws.")
    except Exception as e:
        # A missing bucket/file is non-fatal: fall back to bundled genotypes.
        print(e)

    # Importing here because need to get the latest genotypes before importing.
    from search_spaces.darts import genotypes

    rng_seed = train_utils.RNGSeed(args.run.seed)

    torch.cuda.set_device(args.run.gpu)
    logging.info("gpu device = %d" % args.run.gpu)
    logging.info("args = %s", args.pretty())

    print(dir(genotypes))

    # Look the architecture up by attribute name; getattr avoids running
    # eval() on configuration-supplied input.
    genotype = getattr(genotypes, args.train.arch)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    model = Network(
        args.train.init_channels,
        CIFAR_CLASSES,
        args.train.layers,
        args.train.auxiliary,
        genotype,
    )
    model = model.cuda()

    optimizer, scheduler = train_utils.setup_optimizer(model, args)

    logging.info("param size = %fMB",
                 train_utils.count_parameters_in_MB(model))
    total_params = sum(x.data.nelement() for x in model.parameters())
    logging.info("Model total parameters: {}".format(total_params))

    # Resume from a checkpoint if one exists (locally or in S3).
    try:
        start_epochs, _ = train_utils.load(os.getcwd(),
                                           rng_seed,
                                           model,
                                           optimizer,
                                           s3_bucket=s3_bucket)
        scheduler.last_epoch = start_epochs - 1
    except Exception as e:
        print(e)
        start_epochs = 0

    num_train, num_classes, train_queue, valid_queue = train_utils.create_data_queues(
        args, eval_split=True)

    for epoch in range(start_epochs, args.run.epochs):
        logging.info("epoch %d lr %e", epoch, scheduler.get_lr()[0])
        # Linearly ramp the drop-path probability over the run.
        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        train_acc, train_obj = train(args, train_queue, model, criterion,
                                     optimizer)
        logging.info("train_acc %f", train_acc)

        valid_acc, valid_obj = train_utils.infer(
            valid_queue, model, criterion, report_freq=args.run.report_freq)
        logging.info("valid_acc %f", valid_acc)

        # Checkpoint after every epoch so a restart resumes at epoch + 1.
        train_utils.save(os.getcwd(),
                         epoch + 1,
                         rng_seed,
                         model,
                         optimizer,
                         s3_bucket=s3_bucket)
        scheduler.step()
예제 #4
0
def main():
    """Run architecture search on an ImageNet subset (reads module-level ``args``).

    Builds train/val/test ImageFolder data queues, wraps the search network
    in DataParallel, and runs the epoch loop with an LR warm-up for large
    batches, checkpointing (and optionally S3-syncing) every epoch.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    log = os.path.join(args.save, 'log.txt')
    # Pull any existing log from S3 so logging appends across restarts.
    if args.s3_bucket is not None:
        aws_utils.download_from_s3(log, args.s3_bucket, log)

    rng_seed = train_utils.RNGSeed(args.seed, deterministic=False)

    logging.info("args = %s", args)
    # dataset prepare
    data_dir = os.path.join(args.tmp_data_dir, 'imagenet_search')
    traindir = os.path.join(data_dir, 'train')
    # NOTE(review): chained assignment also rebinds data_dir to the val
    # directory, so data_dir == valdir from here on — confirm intentional.
    valdir = data_dir = os.path.join(data_dir,  'val')

    # Standard ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    # Dataset split: train_data1 feeds train_queue, train_data2 feeds
    # valid_queue (presumably for architecture updates — see train() call),
    # and valid_data feeds the held-out test_queue.
    train_data1 = dset.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_data2 = dset.ImageFolder(valdir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    num_train = len(train_data1)
    num_val = len(train_data2)
    print('# images to train network: %d' % num_train)
    print('# images to validate network: %d' % num_val)

    model = Network(args.init_channels, CLASSES, args.layers, criterion)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    architect = Architect(model, criterion, args)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    history = []

    # Resume from a previous checkpoint when available.
    try:
        start_epochs, history = train_utils.load(
            args.save, rng_seed, model, optimizer, args.s3_bucket
        )
        print(history)
        scheduler.last_epoch = start_epochs - 1
    except Exception as e:
        print(e)
        start_epochs = 0

    test_queue = torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=args.batch_size,
                        shuffle=False,
                        pin_memory=True,
                        num_workers=args.workers)

    train_queue = torch.utils.data.DataLoader(
        train_data1, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data2, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)
    # Shrink the pipeline for quick debugging runs.
    if args.debug:
        train_queue = valid_queue
        valid_queue = test_queue

    lr=args.learning_rate
    for epoch in range(start_epochs, args.epochs):
        # NOTE(review): scheduler.step() at the top of the loop advances the
        # LR before this epoch's training; the warm-up below then overwrites
        # it for the first 5 epochs when batch_size > 256 — confirm intended.
        scheduler.step()
        current_lr = scheduler.get_lr()[0]
        logging.info('Epoch: %d lr: %e', epoch, current_lr)
        # Linear LR warm-up over 5 epochs for large-batch runs.
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * (epoch + 1) / 5.0
            logging.info('Warming-up Epoch: %d, LR: %e', epoch, lr * (epoch + 1) / 5.0)
            print(optimizer)
        genotype = model.module.genotype()
        logging.info('genotype = %s', genotype)
        arch_param = model.module.arch_parameters()
        logging.info(arch_param[0])
        logging.info(arch_param[1])
        logging.info(arch_param[2])
        logging.info(arch_param[3])
        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, optimizer, architect, criterion, lr,epoch)
        logging.info('Train_acc %f', train_acc)

        # validation
        # NOTE(review): validation only runs from epoch 47 onward, presumably
        # to save time while the architecture is still changing — confirm.
        if epoch>= 47:
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('Valid_acc %f', valid_acc)
        # Snapshot architecture parameters each epoch for later analysis.
        history.append([p.data.cpu().numpy() for p in model.module.arch_parameters()])
        logging.info("saving checkpoint")
        train_utils.save(args.save, epoch+1, rng_seed, model, optimizer, history, args.s3_bucket)

    if args.s3_bucket is not None:
        filename = "cnn_genotypes.txt"
        aws_utils.download_from_s3(filename, args.s3_bucket, filename)

        # Append the final genotype to the shared genotypes file and sync
        # both the genotypes file and the log back to S3.
        with open(filename, "a+") as f:
            f.write("\n")
            f.write("{}{} = {}".format('edarts', args.seed, genotype))
        aws_utils.upload_to_s3(filename, args.s3_bucket, filename)
        aws_utils.upload_to_s3(log, args.s3_bucket, log)
예제 #5
0
            # print loss
            if epoch % PRINT_EVERY == 0:
                print('[%s (%d %.1f%%)]' % (train_utils.time_since(START), epoch, float(epoch) / NUM_EPOCHS * 100))
                #print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]'% ('Train Loss: ', train_loss, ' R-sq: ', train_r_sq, ' Accu:', train_accu, train_accu2))
                #print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]'% ('Valid Loss: ', val_loss, ' R-sq: ', val_r_sq, ' Accu:', val_accu, val_accu2))

                if contrastive:
                    #print('[%s %0.5f, %s %0.5f, %s %0.5f]'%('Train Contrastive Loss: ', loss_contrastive_train,'Train CE Loss:', ce_loss_train, 'Train Accuracy: ', acc_contrastive_train))
                    #print('[%s %0.5f, %s %0.5f, %s %0.5f]'%('Validation Contrastive Loss: ', loss_contrastive_val,'Validation CE Loss:', ce_loss_val, 'Validation Accuracy: ', acc_contrastive_val))
                    print('[%s %0.5f]'%('Train Contrastive Loss: ', loss_contrastive_train))
                    print('[%s %0.5f]'%('Validation Contrastive Loss: ', loss_contrastive_val))
            # save model if best validation accuracy
            if loss_contrastive_val.item() < best_loss_contrastive_val: #acc_contrastive_val > best_acc_contrastive_val: #
                n = 'pc_contrastive_runs/' + NAME + '_best'
                train_utils.save(n, perf_model)
                #best_acc_contrastive_val = acc_contrastive_val
                best_loss_contrastive_val = loss_contrastive_val.item()
                best_epoch = epoch
            # store the best r-squared value from training
            if val_r_sq > best_valrsq:
                best_valrsq = val_r_sq
            if best_epoch < epoch - 250 and earlystop:
                break
        
        train_utils.save('pc_contrastive_runs/'+NAME, perf_model)
    else:
        try:
            perf_model.load_state_dict(torch.load('pc_contrastive_runs/' + NAME))
        except:
            pass 
예제 #6
0
        log_value('train_accu2', train_accu2, epoch)
        log_value('val_accu2', val_accu2, epoch)
        # print loss
        if epoch % PRINT_EVERY == 0:
            print('[%s (%d %.1f%%)]' % (train_utils.time_since(START), epoch,
                                        float(epoch) / NUM_EPOCHS * 100))
            print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]' %
                  ('Train Loss: ', train_loss, ' R-sq: ', train_r_sq, ' Accu:',
                   train_accu, train_accu2))
            print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]' %
                  ('Valid Loss: ', val_loss, ' R-sq: ', val_r_sq, ' Accu:',
                   val_accu, val_accu2))
        # save model if best validation loss
        if val_loss < best_val_loss:
            n = file_info + '_best'
            train_utils.save(n, perf_model)
            best_val_loss = val_loss
    print("Saving...")
    train_utils.save(file_info, perf_model)
except KeyboardInterrupt:
    print("Saving before quit...")
    train_utils.save(file_info, perf_model)

# test
# test of full length data
test_loss, test_r_sq, test_accu, test_accu2 = eval_utils.eval_model(
    perf_model, criterion, tef, METRIC, MTYPE, CTYPE)
print('[%s %0.5f, %s %0.5f, %s %0.5f %0.5f]' %
      ('Testing Loss: ', test_loss, ' R-sq: ', test_r_sq, ' Accu:', test_accu,
       test_accu2))
예제 #7
0
def main(args):
    """Perform NAS and dump the best genotype found to a JSON file.

    Runs the search epoch loop (training, architecture updates, TensorBoard
    logging, cell visualization, validation), tracks the best genotype by
    validation accuracy (bi-level) or train accuracy (single-level), and
    finally writes it to ``args.run.genotype_path`` and optionally to S3.

    Args:
        args: config object with ``run``, ``train`` and ``search`` sections
            (seed, gpu, epochs, s3_bucket, search method/space, ...).
    """
    np.set_printoptions(precision=3)
    save_dir = os.getcwd()

    log = os.path.join(save_dir, "log.txt")

    # Setup SummaryWriter
    summary_dir = os.path.join(save_dir, "summary")
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)

    # own writer that I use to keep track of interesting variables
    own_writer = SummaryWriter(os.path.join(save_dir, 'tensorboard'))

    if args.run.s3_bucket is not None:
        aws_utils.download_from_s3(log, args.run.s3_bucket, log)

        train_utils.copy_code_to_experiment_dir(
            "/home/julienf/git/gaea_release/cnn", save_dir)
        aws_utils.upload_directory(os.path.join(save_dir, "scripts"),
                                   args.run.s3_bucket)

    train_utils.set_up_logging(log)

    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    torch.cuda.set_device(args.run.gpu)
    logging.info("Search hyperparameters:")
    logging.info(args.pretty())

    # Set random seeds for random, numpy, torch and cuda
    rng_seed = train_utils.RNGSeed(args.run.seed)

    # Load respective architect
    if args.search.method in ["edarts", "gdarts", "eedarts"]:
        if args.search.fix_alphas:
            from architect.architect_edarts_edge_only import (
                ArchitectEDARTS as Architect, )
        else:
            from architect.architect_edarts import ArchitectEDARTS as Architect
    elif args.search.method in ["darts", "fdarts"]:
        from architect.architect_darts import ArchitectDARTS as Architect
    elif args.search.method == "egdas":
        from architect.architect_egdas import ArchitectEGDAS as Architect
    else:
        raise NotImplementedError

    # Load respective search spaces
    if args.search.search_space in ["darts", "darts_small"]:
        from search_spaces.darts.model_search import DARTSNetwork as Network
    elif "nas-bench-201" in args.search.search_space:
        from search_spaces.nasbench_201.model_search import (
            NASBENCH201Network as Network, )
    elif args.search.search_space == "pcdarts":
        from search_spaces.pc_darts.model_search import PCDARTSNetwork as Network
    else:
        raise NotImplementedError

    if args.train.smooth_cross_entropy:
        criterion = train_utils.cross_entropy_with_label_smoothing
    else:
        criterion = nn.CrossEntropyLoss()

    num_classes, (train_queue, train_2_queue), valid_queue, test_queue, (
        number_train, number_valid,
        number_test) = train_utils.create_cifar10_data_queues_own(args)

    logging.info(f"Dataset: {args.run.dataset}, num_classes: {num_classes}")
    logging.info(f"Number of training images: {number_train}")
    if args.search.single_level:
        logging.info(
            f"Number of validation images (unused during search): {number_valid}"
        )
    else:
        logging.info(
            f"Number of validation images (used during search): {number_valid}"
        )
    logging.info(
        f"Number of test images (unused during search): {number_test}")

    model = Network(
        args.train.init_channels, num_classes, args.search.nodes,
        args.train.layers, criterion, **{
            "auxiliary": args.train.auxiliary,
            "search_space_name": args.search.search_space,
            "exclude_zero": args.search.exclude_zero,
            "track_running_stats": args.search.track_running_stats,
        })

    model = model.cuda()
    logging.info("param size = %fMB",
                 train_utils.count_parameters_in_MB(model))

    optimizer, scheduler = train_utils.setup_optimizer(model, args)

    # TODO: separate args by model, architect, etc
    # TODO: look into using hydra for config files
    architect = Architect(model, args, writer)

    # Try to load a previous checkpoint
    try:
        start_epochs, history, _, _ = train_utils.load(save_dir, rng_seed,
                                                       model, optimizer,
                                                       architect,
                                                       args.run.s3_bucket)
        scheduler.last_epoch = start_epochs - 1
        # TODO: why are data queues reloaded?
        num_classes, (train_queue, train_2_queue), valid_queue, test_queue, (
            number_train, number_valid,
            number_test) = train_utils.create_cifar10_data_queues_own(args)
        logging.info(
            'Resumed training from a previous checkpoint. Runtime measurement will be wrong.'
        )
        train_start_time = 0
    except Exception as e:
        logging.info(e)
        start_epochs = 0
        train_start_time = timer()

    best_valid = 0  # for single-level search, corresponds to best train accuracy observed so far
    epoch_best_valid = 0  # for single-level search, corresponds to the epoch of the best observed train accuracy so far
    # Fallback so best_genotype is always bound, even if the epoch loop never
    # runs (e.g. resuming at the final epoch) or never improves on best_valid.
    best_genotype = architect.genotype()
    overall_visualization_time = 0  # don't count visualization into runtime
    for epoch in range(start_epochs, args.run.epochs):
        lr = scheduler.get_lr()[0]
        logging.info(f"| Epoch: {epoch:3d} / {args.run.epochs} | lr: {lr} |")

        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        # training returns top1, loss and top5
        train_acc, train_obj, train_top5 = train(
            args,
            train_queue,
            valid_queue if train_2_queue is None else
            train_2_queue,  # valid_queue for bi-level search, train_2_queue for single-level search
            model,
            architect,
            criterion,
            optimizer,
            lr,
        )
        architect.baseline = train_obj
        architect.update_history()
        architect.log_vars(epoch, writer)

        if "update_lr_state" in dir(scheduler):
            scheduler.update_lr_state(train_obj)

        logging.info(f"| train_acc: {train_acc} |")

        # History tracking
        for vs in [("alphas", architect.alphas), ("edges", architect.edges)]:
            for ct in vs[1]:
                v = vs[1][ct]
                logging.info("{}-{}".format(vs[0], ct))
                logging.info(v)
        # Calling genotypes sets alphas to best arch for EGDAS and MEGDAS
        # so calling here before infer.
        genotype = architect.genotype()
        logging.info("genotype = %s", genotype)

        # log epoch values to tensorboard
        own_writer.add_scalar('Loss/train', train_obj, epoch)
        own_writer.add_scalar('Top1/train', train_acc, epoch)
        own_writer.add_scalar('Top5/train', train_top5, epoch)
        own_writer.add_scalar('lr', lr, epoch)

        # visualize Genotype: render normal/reduce cells to PNG and log them
        # as images; this wall time is excluded from the reported runtime.
        start_visualization = timer()
        genotype_graph_normal = visualize.plot(genotype.normal,
                                               "",
                                               return_type="graph",
                                               output_format='png')
        binary_normal = genotype_graph_normal.pipe()
        stream_normal = io.BytesIO(binary_normal)
        graph_normal = np.array(PIL.Image.open(stream_normal).convert("RGB"))
        own_writer.add_image("Normal_Cell",
                             graph_normal,
                             epoch,
                             dataformats="HWC")

        genotype_graph_reduce = visualize.plot(genotype.reduce,
                                               "",
                                               return_type="graph",
                                               output_format='png')
        binary_reduce = genotype_graph_reduce.pipe()
        stream_reduce = io.BytesIO(binary_reduce)
        graph_reduce = np.array(PIL.Image.open(stream_reduce).convert("RGB"))
        own_writer.add_image("Reduce_Cell",
                             graph_reduce,
                             epoch,
                             dataformats="HWC")
        end_visualization = timer()
        overall_visualization_time += (end_visualization - start_visualization)

        # log validation metrics, but don't utilize them for decisions during single-level search
        valid_acc, valid_obj, valid_top5 = train_utils.infer(
            valid_queue,
            model,
            criterion,
            report_freq=args.run.report_freq,
            discrete=args.search.discrete,
        )
        own_writer.add_scalar('Loss/valid', valid_obj, epoch)
        own_writer.add_scalar('Top1/valid', valid_acc, epoch)
        own_writer.add_scalar('Top5/valid', valid_top5, epoch)
        logging.info(f"| valid_acc: {valid_acc} |")

        # Track the best genotype: by valid accuracy for bi-level search,
        # by train accuracy for single-level search.
        if not args.search.single_level:
            if valid_acc > best_valid:
                best_valid = valid_acc
                best_genotype = architect.genotype()
                epoch_best_valid = epoch
        else:
            if train_acc > best_valid:
                best_valid = train_acc
                best_genotype = architect.genotype()
                epoch_best_valid = epoch

        train_utils.save(
            save_dir,
            epoch + 1,
            rng_seed,
            model,
            optimizer,
            architect,
            save_history=True,
            s3_bucket=args.run.s3_bucket,
        )

        scheduler.step()

    train_end_time = timer()
    logging.info(
        f"Visualization of cells during search took a total of {timedelta(seconds=overall_visualization_time)} (hh:mm:ss)."
    )
    logging.info(f"This time is not included in the runtime given below.\n")
    logging.info(
        f"Training finished after {timedelta(seconds=((train_end_time - train_start_time) - overall_visualization_time))}(hh:mm:ss)."
    )

    if args.search.single_level:
        logging.info((
            f"\nBecause single-level search is performed, the best genotype was not selected according to the best achieved validation accuracy "
            f"but according to the best train accuracy."))
    logging.info(
        f"\nOverall best found genotype with validation accuracy of {best_valid} (found in epoch {epoch_best_valid}):"
    )
    logging.info(f"{best_genotype}")

    # dump best genotype to json file, so that we can load it during evaluation phase (in train_final.py)
    genotype_dict = best_genotype._asdict()
    for key, val in genotype_dict.items():
        if isinstance(val, range):  # ranges are not JSON serializable
            genotype_dict[key] = list(val)
    if os.path.splitext(args.run.genotype_path)[1] != '.json':
        args.run.genotype_path += '.json'
    with open(args.run.genotype_path, 'w') as genotype_file:
        json.dump(genotype_dict, genotype_file, indent=4)

    logging.info(
        f"Search finished. Dumped best genotype into {args.run.genotype_path}")

    if args.run.s3_bucket is not None:
        filename = "cnn_genotypes.txt"
        aws_utils.download_from_s3(filename, args.run.s3_bucket, filename)

        with open(filename, "a+") as f:
            f.write("\n")
            f.write("{}{}{}{} = {}".format(
                args.search.search_space,
                args.search.method,
                args.run.dataset.replace("-", ""),
                args.run.seed,
                best_genotype,
            ))
        aws_utils.upload_to_s3(filename, args.run.s3_bucket, filename)
        aws_utils.upload_to_s3(log, args.run.s3_bucket, log)
예제 #8
0
def main(args):
    """Perform NAS and append the best found genotype to the S3 genotypes file.

    Runs the search epoch loop (training plus architecture updates), tracks
    the genotype with the best validation accuracy, and — if an S3 bucket is
    configured — appends it to ``cnn_genotypes.txt`` and uploads the log.

    Args:
        args: config object with ``run``, ``train`` and ``search`` sections
            (seed, gpu, epochs, s3_bucket, search method/space, ...).
    """
    np.set_printoptions(precision=3)
    save_dir = os.getcwd()

    log = os.path.join(save_dir, "log.txt")

    # Setup SummaryWriter
    summary_dir = os.path.join(save_dir, "summary")
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)

    if args.run.s3_bucket is not None:
        aws_utils.download_from_s3(log, args.run.s3_bucket, log)

        train_utils.copy_code_to_experiment_dir("/code/nas-theory/cnn",
                                                save_dir)
        aws_utils.upload_directory(os.path.join(save_dir, "scripts"),
                                   args.run.s3_bucket)

    train_utils.set_up_logging(log)

    if not torch.cuda.is_available():
        logging.info("no gpu device available")
        sys.exit(1)

    torch.cuda.set_device(args.run.gpu)
    logging.info("gpu device = %d" % args.run.gpu)
    logging.info("args = %s", args.pretty())

    rng_seed = train_utils.RNGSeed(args.run.seed)

    # Pick the architect implementation matching the requested search method.
    if args.search.method in ["edarts", "gdarts", "eedarts"]:
        if args.search.fix_alphas:
            from architect.architect_edarts_edge_only import (
                ArchitectEDARTS as Architect, )
        else:
            from architect.architect_edarts import ArchitectEDARTS as Architect
    elif args.search.method in ["darts", "fdarts"]:
        from architect.architect_darts import ArchitectDARTS as Architect
    elif args.search.method == "egdas":
        from architect.architect_egdas import ArchitectEGDAS as Architect
    else:
        raise NotImplementedError

    # Pick the search-space network matching the requested search space.
    if args.search.search_space in ["darts", "darts_small"]:
        from search_spaces.darts.model_search import DARTSNetwork as Network
    elif "nas-bench-201" in args.search.search_space:
        from search_spaces.nasbench_201.model_search import (
            NASBENCH201Network as Network, )
    elif args.search.search_space == "pcdarts":
        from search_spaces.pc_darts.model_search import PCDARTSNetwork as Network
    else:
        raise NotImplementedError

    if args.train.smooth_cross_entropy:
        criterion = train_utils.cross_entropy_with_label_smoothing
    else:
        criterion = nn.CrossEntropyLoss()

    num_train, num_classes, train_queue, valid_queue = train_utils.create_data_queues(
        args)

    print("dataset: {}, num_classes: {}".format(args.run.dataset, num_classes))

    model = Network(
        args.train.init_channels, num_classes, args.search.nodes,
        args.train.layers, criterion, **{
            "auxiliary": args.train.auxiliary,
            "search_space_name": args.search.search_space,
            "exclude_zero": args.search.exclude_zero,
            "track_running_stats": args.search.track_running_stats,
        })
    model = model.cuda()
    logging.info("param size = %fMB",
                 train_utils.count_parameters_in_MB(model))

    optimizer, scheduler = train_utils.setup_optimizer(model, args)

    # TODO: separate args by model, architect, etc
    # TODO: look into using hydra for config files
    architect = Architect(model, args, writer)

    # Try to load a previous checkpoint
    try:
        start_epochs, history = train_utils.load(save_dir, rng_seed, model,
                                                 optimizer, architect,
                                                 args.run.s3_bucket)
        scheduler.last_epoch = start_epochs - 1
        (
            num_train,
            num_classes,
            train_queue,
            valid_queue,
        ) = train_utils.create_data_queues(args)
    except Exception as e:
        logging.info(e)
        start_epochs = 0

    best_valid = 0
    # Fallback so best_genotype is always bound even if no validation pass
    # ever beats best_valid (previously this could raise NameError below).
    best_genotype = architect.genotype()
    for epoch in range(start_epochs, args.run.epochs):
        lr = scheduler.get_lr()[0]
        logging.info("epoch %d lr %e", epoch, lr)

        model.drop_path_prob = args.train.drop_path_prob * epoch / args.run.epochs

        # training
        train_acc, train_obj = train(
            args,
            train_queue,
            valid_queue,
            model,
            architect,
            criterion,
            optimizer,
            lr,
        )
        architect.baseline = train_obj
        architect.update_history()
        architect.log_vars(epoch, writer)

        if "update_lr_state" in dir(scheduler):
            scheduler.update_lr_state(train_obj)

        logging.info("train_acc %f", train_acc)

        # History tracking
        for vs in [("alphas", architect.alphas), ("edges", architect.edges)]:
            for ct in vs[1]:
                v = vs[1][ct]
                logging.info("{}-{}".format(vs[0], ct))
                logging.info(v)
        # Calling genotypes sets alphas to best arch for EGDAS and MEGDAS
        # so calling here before infer.
        genotype = architect.genotype()
        logging.info("genotype = %s", genotype)

        # For bi-level search, validate each epoch and keep the best genotype.
        if not args.search.single_level:
            valid_acc, valid_obj = train_utils.infer(
                valid_queue,
                model,
                criterion,
                report_freq=args.run.report_freq,
                discrete=args.search.discrete,
            )
            if valid_acc > best_valid:
                best_valid = valid_acc
                best_genotype = architect.genotype()
            logging.info("valid_acc %f", valid_acc)

        train_utils.save(
            save_dir,
            epoch + 1,
            rng_seed,
            model,
            optimizer,
            architect,
            save_history=True,
            s3_bucket=args.run.s3_bucket,
        )

        scheduler.step()

    # Final validation pass (runs for single-level search too).
    valid_acc, valid_obj = train_utils.infer(
        valid_queue,
        model,
        criterion,
        report_freq=args.run.report_freq,
        discrete=args.search.discrete,
    )
    if valid_acc > best_valid:
        best_valid = valid_acc
        best_genotype = architect.genotype()
    logging.info("valid_acc %f", valid_acc)

    if args.run.s3_bucket is not None:
        filename = "cnn_genotypes.txt"
        aws_utils.download_from_s3(filename, args.run.s3_bucket, filename)

        # Append the best genotype under a name derived from the search
        # configuration, then sync the file and the log back to S3.
        with open(filename, "a+") as f:
            f.write("\n")
            f.write("{}{}{}{} = {}".format(
                args.search.search_space,
                args.search.method,
                args.run.dataset.replace("-", ""),
                args.run.seed,
                best_genotype,
            ))
        aws_utils.upload_to_s3(filename, args.run.s3_bucket, filename)
        aws_utils.upload_to_s3(log, args.run.s3_bucket, log)