Example #1
def train_gpu(
    model: torch.nn.Module,
    train_params: dict,
    data_root: str,
    momentum: float,
    weight_decay: float,
    CIFAR_CLASSES: int,
    learning_rate: float,
    layers: int,
    batch_size: int,
    epochs: int,
    drop_path_prob: float = 0.0,
    save_pth: str = "",
    args: argparse.Namespace = None,
    clip=None,
):

    torch.cuda.set_device(0)
    cudnn.benchmark = True
    cudnn.enabled = True

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    # Prefer the weight decay from args; fall back to the documented defaults
    # (0.001 for 'lammarckian' weight init, 3e-4 otherwise) when it is missing.
    try:
        weight_decay = args.weight_decay
    except AttributeError:
        logger.warning(
            "Missing weight decay arg - default to 0.001 for lammarckian and 3e-4 otherwise"
        )

        if getattr(args, 'weight_init', None) == 'lammarckian':
            weight_decay = 0.001
        else:
            weight_decay = 3e-4

    optimizer = torch.optim.SGD(parameters,
                                learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    if args.cutout > 0:
        train_transform.transforms.append(Cutout(args.cutout))

    train_data = CifarDataset(transform=train_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              num_workers=0)

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    valid_data = my_cifar10.CIFAR10(root=data_root,
                                    train=False,
                                    download=False,
                                    transform=valid_transform)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              num_workers=0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(epochs))

    for epoch in range(epochs):

        logger.info("epoch %d lr %e" % (epoch, scheduler.get_lr()[0]))
        model.droprate = drop_path_prob * epoch / epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer,
                                     train_params, clip)
        logger.info("train_acc %f" % (train_acc))
        scheduler.step()

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logger.info("valid_acc %f" % valid_acc)
    model = add_flops_counting_methods(model)
    model.eval()
    model.start_flops_count()
    random_data = torch.randn(1, 3, 32, 32)
    model(torch.autograd.Variable(random_data).to(device))
    n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)

    return valid_acc, n_flops
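
The drop-path rate above ramps linearly with the epoch while the learning rate follows cosine annealing. A minimal, self-contained sketch of both schedules, assuming only torch; the stand-in Linear model and the hyperparameter values are illustrative, not taken from the example:

import torch

model = torch.nn.Linear(10, 2)  # stand-in model so the optimizer has parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)

epochs, drop_path_prob = 5, 0.2
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

for epoch in range(epochs):
    droprate = drop_path_prob * epoch / epochs  # same ramp as model.droprate above
    print(epoch, scheduler.get_last_lr()[0], droprate)
    # ... a training pass over train_queue would go here ...
    optimizer.step()   # PyTorch expects an optimizer step before each scheduler step
    scheduler.step()
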
Example #2
def main(macro_genome, micro_genome, epochs, search_space='micro',
         save='Design_1', expr_root='search', seed=0, gpu=0, init_channels=24,
         layers=11, auxiliary=False, cutout=False, drop_path_prob=0.0, batch_size=128):

    # ---- train logger ----------------- #
    save_pth = os.path.join(expr_root, '{}'.format(save))
    utils.create_exp_dir(save_pth)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')

    # ---- parameter values setting ----- #
    CIFAR_CLASSES = config_dict()['n_classes']
    INPUT_CHANNELS = config_dict()['n_channels']
    learning_rate = 0.025
    momentum = 0.9
    weight_decay = 3e-4
    data_root = '../data'
    cutout_length = 16
    auxiliary_weight = 0.4
    grad_clip = 5
    report_freq = 50
    train_params = {
        'auxiliary': auxiliary,
        'auxiliary_weight': auxiliary_weight,
        'grad_clip': grad_clip,
        'report_freq': report_freq,
    }

    if search_space == 'micro' or search_space == 'micro_garbage':
        genome = micro_genome
        genotype = micro_encoding.decode(genome)
        model = Network(init_channels, CIFAR_CLASSES, config_dict()['n_channels'], layers, auxiliary, genotype)
    elif search_space == 'macro' or search_space == 'macro_garbage':
        genome = macro_genome
        genotype = macro_encoding.decode(genome)
        channels = [(INPUT_CHANNELS, init_channels),
                    (init_channels, 2*init_channels),
                    (2*init_channels, 4*init_channels)]
        model = EvoNetwork(genotype, channels, CIFAR_CLASSES, (config_dict()['INPUT_HEIGHT'], config_dict()['INPUT_WIDTH']), decoder='residual')
    elif search_space == 'micromacro':
        genome = [macro_genome, micro_genome]
        macro_genotype = macro_encoding.decode(macro_genome)
        micro_genotype = micro_encoding.decode(micro_genome)
        genotype = [macro_genotype, micro_genotype]
        set_config('micro_creator', make_micro_creator(micro_genotype, convert=False))
        channels = [(INPUT_CHANNELS, init_channels),
                    (init_channels, 2 * init_channels),
                    (2 * init_channels, 4 * init_channels)]
        model = EvoNetwork(macro_genotype, channels, CIFAR_CLASSES,
                           (config_dict()['INPUT_HEIGHT'], config_dict()['INPUT_WIDTH']), decoder='residual')

    else:
        raise NameError('Unknown search space type')

    # logging.info("Genome = %s", genome)
    logging.info("Architecture = %s", genotype)

    torch.cuda.set_device(gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)

    n_params = (np.sum(np.prod(v.size()) for v in filter(lambda p: p.requires_grad, model.parameters())) / 1e6)
    model = model.to(device)

    logging.info("param size = %fMB", n_params)

    if config_dict()['problem'] == 'classification':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.MSELoss()
    criterion = criterion.cuda()


    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(
        parameters,
        learning_rate,
        momentum=momentum,
        weight_decay=weight_decay
    )

    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    if cutout:
        train_transform.transforms.append(utils.Cutout(cutout_length))

    train_transform.transforms.append(transforms.Normalize(CIFAR_MEAN, CIFAR_STD))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    train_data = my_cifar10.CIFAR10(root=data_root, train=True, download=False, transform=train_transform)
    valid_data = my_cifar10.CIFAR10(root=data_root, train=False, download=False, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size,
        # sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=1)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=batch_size,
        # sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=1)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs))

    for epoch in range(epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.droprate = drop_path_prob * epoch / epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer, train_params)
        logging.info(f'train_{config_dict()["performance_measure"]} %f', train_acc)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info(f'valid_{config_dict()["performance_measure"]} %f', valid_acc)

    # calculate for flops
    model = add_flops_counting_methods(model)
    model.eval()
    model.start_flops_count()
    random_data = torch.randn(1, INPUT_CHANNELS, config_dict()['INPUT_HEIGHT'], config_dict()['INPUT_WIDTH'])
    model(torch.autograd.Variable(random_data).to(device))
    n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
    logging.info('flops = %f', n_flops)

    # save to file
    # os.remove(os.path.join(save_pth, 'log.txt'))
    with open(os.path.join(save_pth, 'log.txt'), "w") as file:
        file.write("Genome = {}\n".format(genome))
        file.write("Architecture = {}\n".format(genotype))
        file.write("param size = {}MB\n".format(n_params))
        file.write("flops = {}MB\n".format(n_flops))
        file.write("valid_acc = {}\n".format(valid_acc))

    # logging.info("Architecture = %s", genotype))

    return {
        'valid_acc': valid_acc,
        'params': n_params,
        'flops': n_flops,
    }
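
The parameter count above sums np.prod(v.size()) over a generator with np.sum. A minimal sketch, assuming only torch, of the equivalent idiomatic computation on a stand-in model (the Sequential here is illustrative only):

import torch

# stand-in model in place of the evolved network
model = torch.nn.Sequential(torch.nn.Conv2d(3, 24, 3), torch.nn.ReLU(), torch.nn.Conv2d(24, 10, 1))

# trainable parameter count in millions, equivalent to the np.sum/np.prod expression above
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print("param size = %fMB" % n_params)
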
Example #3
def main(genome,
         epochs,
         search_space='micro',
         save='Design_1',
         expr_root='search',
         seed=0,
         gpu=0,
         init_channels=24,
         layers=11,
         auxiliary=False,
         cutout=False,
         drop_path_prob=0.0,
         train_dataset="",
         val_dataset=""):

    # ---- train logger ----------------- #
    save_pth = os.path.join(expr_root, '{}'.format(save))
    utils.create_exp_dir(save_pth)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    # fh = logging.FileHandler(os.path.join(save_pth, 'log.txt'))
    # fh.setFormatter(logging.Formatter(log_format))
    # logging.getLogger().addHandler(fh)

    # ---- parameter values setting ----- #
    NUM_CLASSES = 4
    CIFAR_CLASSES = NUM_CLASSES
    DATA_SHAPE = (128, 128)
    INPUT_CHANNELS = 3
    learning_rate = 0.025
    momentum = 0.9
    weight_decay = 3e-4
    data_root = '../data'
    batch_size = 16
    cutout_length = 16
    auxiliary_weight = 0.4
    grad_clip = 5
    report_freq = 50
    train_params = {
        'auxiliary': auxiliary,
        'auxiliary_weight': auxiliary_weight,
        'grad_clip': grad_clip,
        'report_freq': report_freq,
    }

    if search_space == 'micro':
        genotype = micro_encoding.decode(genome)
        model = Network(init_channels, CIFAR_CLASSES, layers, auxiliary,
                        genotype)
    elif search_space == 'macro':
        genotype = macro_encoding.decode(genome)
        channels = [(INPUT_CHANNELS, init_channels),
                    (init_channels, 2 * init_channels),
                    (2 * init_channels, 4 * init_channels)]
        model = EvoNetwork(genotype,
                           channels,
                           CIFAR_CLASSES,
                           DATA_SHAPE,
                           decoder='residual')
    else:
        raise NameError('Unknown search space type')

    # logging.info("Genome = %s", genome)
    logging.info("Architecture = %s", genotype)

    torch.cuda.set_device(gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)

    n_params = (np.sum(
        np.prod(v.size())
        for v in filter(lambda p: p.requires_grad, model.parameters())) / 1e6)
    model = model.to(device)

    logging.info("param size = %fMB", n_params)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(parameters,
                                learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # TODO: change
    # normalization statistics: the CIFAR-10 values are kept for reference but
    # overridden by this dataset's own mean/std
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
    DATASET_MEAN = [0.4785047, 0.45649716, 0.42604172]
    DATASET_STD = [0.31962952, 0.3112294, 0.31206125]
    CIFAR_MEAN = DATASET_MEAN
    CIFAR_STD = DATASET_STD
    #     # data agumentation
    #     train_transform = transforms.Compose([
    #         transforms.RandomCrop(32, padding=4),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor()
    #     ])

    #     if cutout:
    #         train_transform.transforms.append(utils.Cutout(cutout_length))

    #     train_transform.transforms.append(transforms.Normalize(CIFAR_MEAN, CIFAR_STD))

    #     valid_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    #     ])

    #     train_data = my_cifar10.CIFAR10(root=data_root, train=True, download=True, transform=train_transform)
    #     valid_data = my_cifar10.CIFAR10(root=data_root, train=False, download=True, transform=valid_transform)

    #     # num_train = len(train_data)
    #     # indices = list(range(num_train))
    #     # split = int(np.floor(train_portion * num_train))
    train_data = train_dataset
    valid_data = val_dataset
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        # sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=4)

    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=batch_size,
        # sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True,
        num_workers=4)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(epochs))

    for epoch in range(epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.droprate = drop_path_prob * epoch / epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer,
                                     train_params)
        logging.info('train_acc %f', train_acc)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)

    # calculate for flops
    model = add_flops_counting_methods(model)
    model.eval()
    model.start_flops_count()
    random_data = torch.randn(1, INPUT_CHANNELS, *DATA_SHAPE)
    model(torch.autograd.Variable(random_data).to(device))
    n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
    logging.info('flops = %f', n_flops)

    # save to file
    # os.remove(os.path.join(save_pth, 'log.txt'))
    with open(os.path.join(save_pth, 'log.txt'), "w") as file:
        file.write("Genome = {}\n".format(genome))
        file.write("Architecture = {}\n".format(genotype))
        file.write("param size = {}MB\n".format(n_params))
        file.write("flops = {}MB\n".format(n_flops))
        file.write("valid_acc = {}\n".format(valid_acc))

    # logging.info("Architecture = %s", genotype))

    return {
        'valid_acc': valid_acc,
        'params': n_params,
        'flops': n_flops,
    }
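
utils.Cutout / Cutout is referenced throughout these examples but not defined in them. A minimal sketch of the usual DARTS-style Cutout transform (zeroing one random square patch after ToTensor), offered as an assumption about what the helper does rather than a copy of it:

import numpy as np
import torch


class Cutout(object):
    """Zero out one random length x length square in a CHW image tensor."""

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y, x = np.random.randint(h), np.random.randint(w)
        y1, y2 = np.clip(y - self.length // 2, 0, h), np.clip(y + self.length // 2, 0, h)
        x1, x2 = np.clip(x - self.length // 2, 0, w), np.clip(x + self.length // 2, 0, w)
        mask[y1:y2, x1:x2] = 0.0
        # broadcast the HxW mask over the channel dimension and apply it
        return img * torch.from_numpy(mask).expand_as(img)
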
Example #4
def main(genome,
         epochs,
         search_space='micro',
         save='Design_1',
         expr_root='search',
         seed=0,
         gpu=0,
         init_channels=24,
         layers=11,
         auxiliary=False,
         cutout=False,
         drop_path_prob=0.0,
         data_path="../data",
         dataset="CIFAR10"):

    # ---- train logger ----------------- #
    save_pth = os.path.join(expr_root, '{}'.format(save))
    utils.create_exp_dir(save_pth)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    # fh = logging.FileHandler(os.path.join(save_pth, 'log.txt'))
    # fh.setFormatter(logging.Formatter(log_format))
    # logging.getLogger().addHandler(fh)

    # ---- parameter values setting ----- #
    if dataset == "CIFAR10":
        CLASSES = 10
    elif dataset == "CIFAR100":
        CLASSES = 100
    elif dataset == "Sport8":
        CLASSES = 8
    elif dataset == "MIT67":
        CLASSES = 67
    elif dataset == "flowers102":
        CLASSES = 102
    learning_rate = 0.025
    momentum = 0.9
    weight_decay = 3e-4
    data_root = data_path
    batch_size = 128
    cutout_length = 16
    auxiliary_weight = 0.4
    grad_clip = 5
    report_freq = 50
    train_params = {
        'auxiliary': auxiliary,
        'auxiliary_weight': auxiliary_weight,
        'grad_clip': grad_clip,
        'report_freq': report_freq,
    }

    if search_space == 'micro':
        genotype = micro_encoding.decode(genome)
        if dataset == "CIFAR10" or dataset == "CIFAR100":
            model = NetworkCIFAR(init_channels, CLASSES, layers, auxiliary,
                                 genotype)
        else:
            model = NetworkImageNet(init_channels, CLASSES, layers, auxiliary,
                                    genotype)
    elif search_space == 'macro':
        genotype = macro_encoding.decode(genome)
        channels = [(3, init_channels), (init_channels, 2 * init_channels),
                    (2 * init_channels, 4 * init_channels)]
        model = EvoNetwork(genotype,
                           channels,
                           CLASSES, (32, 32),
                           decoder='residual')
    else:
        raise NameError('Unknown search space type')

    # logging.info("Genome = %s", genome)
    logging.info("Architecture = %s", genotype)

    torch.cuda.set_device(gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(seed)

    n_params = (np.sum(
        np.prod(v.size())
        for v in filter(lambda p: p.requires_grad, model.parameters())) / 1e6)
    model = model.to(device)

    logging.info("param size = %fMB", n_params)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(parameters,
                                learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)
    if dataset == "CIFAR10" or dataset == "CIFAR100":
        MEAN = [0.49139968, 0.48215827, 0.44653124]
        STD = [0.24703233, 0.24348505, 0.26158768]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        if cutout:
            train_transform.transforms.append(utils.Cutout(cutout_length))

        train_transform.transforms.append(transforms.Normalize(MEAN, STD))

        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD),
        ])
    if dataset == "CIFAR10":
        train_data = my_cifar10.CIFAR10(root=data_root,
                                        train=True,
                                        download=True,
                                        transform=train_transform)
        valid_data = my_cifar10.CIFAR10(root=data_root,
                                        train=True,
                                        download=True,
                                        transform=valid_transform)  # note: valid_queue below re-splits train_data, so this is unused
    elif dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=data_root,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=data_root,
                                   train=True,
                                   download=True,
                                   transform=valid_transform)
    else:
        MEAN = [0.485, 0.456, 0.406]
        STD = [0.229, 0.224, 0.225]
        transf_train = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0.2)
        ]
        transf_val = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
        normalize = [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]

        train_transform = transforms.Compose(transf_train + normalize)
        valid_transform = transforms.Compose(transf_val + normalize)
        if cutout:
            train_transform.transforms.append(utils.Cutout(cutout_length))

        train_data = dset.ImageFolder(root=data_path + "/" + dataset +
                                      "/train",
                                      transform=train_transform)
        valid_data = dset.ImageFolder(root=data_path + "/" + dataset + "/test",
                                      transform=valid_transform)

    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    random.shuffle(indices)
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=train_sampler,
        pin_memory=True,
        num_workers=4)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=valid_sampler,
        pin_memory=True,
        num_workers=4)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(epochs))

    for epoch in range(epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.droprate = drop_path_prob * epoch / epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer,
                                     train_params)
        logging.info('train_acc %f', train_acc)

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)

    # calculate for flops
    model = add_flops_counting_methods(model)
    model.eval()
    model.start_flops_count()
    random_data = torch.randn(1, 3, 32, 32)  #to change
    model(torch.autograd.Variable(random_data).to(device))
    n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
    logging.info('flops = %f', n_flops)

    # save to file
    # os.remove(os.path.join(save_pth, 'log.txt'))
    with open(os.path.join(save_pth, 'log.txt'), "w") as file:
        file.write("Genome = {}\n".format(genome))
        file.write("Architecture = {}\n".format(genotype))
        file.write("param size = {}MB\n".format(n_params))
        file.write("flops = {}MB\n".format(n_flops))
        file.write("valid_acc = {}\n".format(valid_acc))
    # logging.info("Architecture = %s", genotype))
    with open(os.path.join(save_pth, 'genotype.txt'), "w") as f:
        f.write(str(genotype))
    return {
        'valid_acc': valid_acc,
        'params': n_params,
        'flops': n_flops,
    }
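
The loaders above carve a single dataset into disjoint train and validation halves with SubsetRandomSampler. A minimal runnable sketch of the same pattern, assuming only torch; the TensorDataset is a dummy stand-in for train_data:

import random

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# dummy 100-image dataset in place of train_data
data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

indices = list(range(len(data)))
random.shuffle(indices)
split = len(data) // 2

train_loader = DataLoader(data, batch_size=16,
                          sampler=SubsetRandomSampler(indices[:split]), pin_memory=True)
valid_loader = DataLoader(data, batch_size=16,
                          sampler=SubsetRandomSampler(indices[split:]), pin_memory=True)

print(len(train_loader.sampler), len(valid_loader.sampler))  # 50 50
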
Example #5
    def run():
        """
        Main function to setup the training loop and evaluation loop.
        See comments for detailed explanation.

        Returns:
            None, but it saves the model weights and model performance, based on the get_map_fn arguments

        """

        # xla will assign a device for each forked run of this function
        device = xm.xla_device()

        # determine if this fork is the master fork to avoid logging and printing 8 times
        master = xm.is_master_ordinal()

        if master:
            logger.info("running at batch size %i" % batch_size)

        criterion = nn.CrossEntropyLoss()

        criterion.to(device)
        model = WRAPPED_MODEL.to(device)

        # standard data prep
        CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
        CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

        train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        if args.cutout > 0:
            train_transform.transforms.append(Cutout(args.cutout))

        train_data = CifarDataset(transform=train_transform)

        # distributed samplers ensure data is sharded to each TPU core
        # if you do not use one, you are only using 1 of the 8 cores
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True,
        )

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size//xm.xrt_world_size(),
            sampler=train_sampler,
            drop_last=True,
            num_workers=0,
        )

        valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        valid_data = my_cifar10.CIFAR10(
            root=data_root, train=False, download=False, transform=valid_transform
        )

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False,
        )

        valid_queue = torch.utils.data.DataLoader(
            valid_data,
            sampler=valid_sampler,
            batch_size=batch_size//xm.xrt_world_size(),
            drop_last=True,
            num_workers=0,
        )

        # standard optimizer stuff
        parameters = filter(lambda p: p.requires_grad, model.parameters())

        if args.opt == "sgd":

            optimizer = torch.optim.SGD(
                parameters,
                args.learning_rate,
                momentum=momentum,
                weight_decay=args.weight_decay,
            )
        elif args.opt == "lamb":
            optimizer = Lamb(
                parameters, lr=args.learning_rate, weight_decay=weight_decay
            )
        else:
            raise NameError("Unknown Optimizer %s" % args.opt)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs))

        # training by epoch loop
        for epoch in range(epochs):

            # the model needs a droprate, so just assign it
            model.droprate = drop_path_prob * epoch / epochs

            start = datetime.datetime.now()
            st = start.strftime("%Y-%m-%d %H:%M:%S")

            if master:
                logger.info("starting epoch %i at %s" % (epoch, st))

            # parallel loader necessary to load data in parallel to each core
            para_loader = pl.ParallelLoader(train_queue, [device]).per_device_loader(
                device
            )
            correct, train_loss, total = train(
                para_loader, model, criterion, optimizer, params, device
            )

            train_acc = 100 * correct / total

            # collect the train accuracies from all cores
            train_acc = xm.mesh_reduce("avg acc", train_acc, np.mean)

            end = datetime.datetime.now()
            duration = (end - start).total_seconds()

            if master:
                logger.info("train_acc %f duration %f" % (train_acc, duration))

            scheduler.step()

        # validate using 8 cores and collect results
        valid_acc, valid_obj = infer(valid_queue, model, criterion, device)
        valid_acc = xm.mesh_reduce("val avg acc", valid_acc, np.mean)

        if master:
            logger.info("valid_acc %f" % valid_acc)

        # count flops
        _ = add_flops_counting_methods(model)
        model.eval()
        model.start_flops_count()
        random_data = torch.randn(1, 3, 32, 32)
        model(torch.autograd.Variable(random_data).to(device))
        n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
        n_flops = xm.mesh_reduce("flops", n_flops, np.mean)

        if master:
            logger.info("flops %f" % n_flops)

        if master:
            logger.info("saving")

        # save weights and results

        xm.save([valid_acc, n_flops], "results.pt")
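
run() is written as a closure to be forked once per TPU core; the spawning code is not part of the example. A hypothetical launcher sketch, assuming torch_xla is installed and a TPU host is available; _mp_fn and its flags argument are illustrative stand-ins, not names from the original:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, flags=None):
    # xmp.spawn calls this once per core; in the example above this is where
    # run() would execute, with xm.xla_device() resolving to that core
    device = xm.xla_device()
    print(index, device)


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(None,), nprocs=8, start_method="fork")
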