Example #1
    def _construct_model_from_theta(self, theta):
        model_clone = Network(self.model._C, self.model._num_classes,
                              self.model._layers,
                              self.model._criterion).cuda()

        for x, y in zip(model_clone.arch_parameters(),
                        self.model.arch_parameters()):
            x.data.copy_(y.data)
        model_dict = self.model.state_dict()

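        # Walk the flat theta vector in named_parameters() order, reshaping
        # each slice back to the matching parameter's shape.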
        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset:offset + v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_clone.load_state_dict(model_dict)
        return model_clone.cuda()
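
A minimal sketch of the inverse operation, assuming theta is produced by concatenating the model's parameters in named_parameters() order (the helper name flatten_params below is illustrative, not part of the original code):

def flatten_params(model):
    # Concatenate all parameters as one 1-D tensor, in named_parameters()
    # order, so _construct_model_from_theta can slice it back by size.
    return torch.cat([v.data.view(-1) for _, v in model.named_parameters()])
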
Example #2
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

arch_optimizer = torch.optim.Adam(model.arch_parameters(),
                                  lr=args.arch_learning_rate,
                                  betas=(0.5, 0.999),
                                  weight_decay=args.arch_weight_decay)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
    filepath = args.checkpoint_format.format(exp=args.save,
                                             epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, float(args.epochs), eta_min=args.learning_rate_min)
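# Note: CosineAnnealingLR decays the learning rate as
#   eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max)),
# so here it anneals from the wrapped optimizer's initial LR down to
# args.learning_rate_min over args.epochs epochs.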

architect = Architect(model, args)

# model_path = "./search-EXP-final/weights.pt"
# model.load_state_dict(torch.load(model_path))

start_time = time.time()


def main():
    utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    print(args)

    seed = random.randint(1, 100000000)
    print(seed)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    n_channels = 3
    n_bins = 2.**args.n_bits

    # Define model and loss criteria
    model = SearchNetwork(n_channels,
                          args.n_flow,
                          args.n_block,
                          n_bins,
                          affine=args.affine,
                          conv_lu=not args.no_lu)
    model = nn.DataParallel(model, [args.gpu])
    model.load_state_dict(
        torch.load("architecture.pt", map_location="cuda:{}".format(args.gpu)))
    model = model.module
    genotype = model.sample_architecture()

    with open(args.save + '/genotype.pkl', 'wb') as fp:
        pickle.dump(genotype, fp)

    model_single = EnsembleNetwork(n_channels,
                                   args.n_flow,
                                   args.n_block,
                                   n_bins,
                                   genotype,
                                   affine=args.affine,
                                   conv_lu=not args.no_lu)
    model = model_single
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

    dataset = iter(sample_cifar10(args.batch, args.img_size))

    # Sample generated images
    z_sample = []
    z_shapes = calc_z_shapes(n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    with tqdm(range(args.iter)) as pbar:
        for i in pbar:
            # Training procedure
            model.train()

            # Get a random minibatch from the search queue with replacement
            input, _ = next(dataset)
            input = input.cuda(non_blocking=True)

            log_p, logdet, _ = model(input + torch.rand_like(input) / n_bins)

            logdet = logdet.mean()
            loss, _, _ = likelihood_loss(log_p, logdet, args.img_size, n_bins)

            # Optimize model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description("Loss: {}".format(loss.item()))

            # Save generated samples
            if i % 100 == 0:
                with torch.no_grad():
                    tvutils.save_image(
                        model_single.reverse(z_sample).cpu().data,
                        "{}/samples/{}.png".format(args.save,
                                                   str(i + 1).zfill(6)),
                        normalize=False,
                        nrow=10,
                    )

            # Save checkpoint
            if i % 1000 == 0:
                utils.save(model, os.path.join(args.save, 'latest_weights.pt'))
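
calc_z_shapes is not shown in this snippet; in Glow-style multi-scale models each block halves the spatial resolution and splits off part of the channels as a latent, so a typical implementation (an assumption, not taken from this code) looks like:

def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    # One latent shape per block: spatial size halves at every block and the
    # channel count doubles; the last block keeps all remaining channels
    # (x4 from the final squeeze). n_flow is unused here but kept to match
    # the call above.
    z_shapes = []
    for _ in range(n_block - 1):
        input_size //= 2
        n_channel *= 2
        z_shapes.append((n_channel, input_size, input_size))
    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))
    return z_shapes
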
Example #4
    cpm_layers=1,
    auxiliary_loss=False,
)

state_dict = torch.load(args.trained_model,
                        map_location=lambda storage, loc: storage)
print("Pretrained model loading OK...")
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if "auxiliary" not in k:
        name = k[7:]  # remove module.
        new_state_dict[name] = v
    else:
        print("Auxiliary loss is used when retraining.")

net.load_state_dict(new_state_dict)
net.cuda()
net.eval()
print("Finished loading model!")

transform = TestBaseTransform((104, 117, 123))


def preprocess(img):
    x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1)
    x = x.unsqueeze(0).cuda()
    return x


save_path = args.path + "_res"
os.makedirs(save_path, exist_ok=True)
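
A minimal sketch of how this setup might be driven over a folder of images; cv2, glob and the handling of the detector's output are assumptions, since net's output format is not shown above:

import cv2
import glob

for img_path in sorted(glob.glob(os.path.join(args.path, "*.jpg"))):
    img = cv2.imread(img_path)
    x = preprocess(img)
    with torch.no_grad():
        detections = net(x)  # output layout depends on the detector (assumed)
    # Persist raw outputs next to the originals for later post-processing.
    torch.save(detections, os.path.join(save_path,
                                        os.path.basename(img_path) + ".pt"))
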
Example #5
def main():

    # LOAD CONFIGS ##############################################################
    args = get_args()
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_no

    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    
    use_gpu = torch.cuda.is_available()

    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)
    random.seed(args.rand_seed)

    # LOAD DATA #################################################################

    def convert_param(original_lists):
        ctype, value = original_lists[0], original_lists[1]
        is_list = isinstance(value, list)
        if not is_list:
            value = [value]
        outs = []
        for x in value:
            if ctype == 'int':
                x = int(x)
            elif ctype == 'str':
                x = str(x)
            elif ctype == 'bool':
                x = bool(int(x))
            elif ctype == 'float':
                x = float(x)
            elif ctype == 'none':
                if x.lower() != 'none':
                    raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
                x = None
            else:
                raise TypeError('Does not know this type : {:}'.format(ctype))
            outs.append(x)
        if not is_list:
            outs = outs[0]
        return outs
    from collections import namedtuple
    with open('../data/cifar-split.txt', 'r') as f:
        data = json.load(f)
        content = {k: convert_param(v) for k, v in data.items()}
        Arguments = namedtuple('Configure', ' '.join(content.keys()))
        content = Arguments(**content)
    cifar_split = content
    train_split, valid_split = cifar_split.train, cifar_split.valid
    print(len(train_split), len(valid_split))
    
    if args.dataset == "cifar10":
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std  = [x / 255 for x in [63.0, 62.1, 66.7]]

        lists = [transforms.RandomHorizontalFlip(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.ToTensor(),
                 transforms.Normalize(mean, std)]
        transform_train = transforms.Compose(lists)
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        train_dataset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_test),
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)
        CLASS = 10
    elif args.dataset == "cifar100":
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std  = [x / 255 for x in [68.2, 65.4, 70.4]]

        lists = [transforms.RandomHorizontalFlip(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.ToTensor(),
                 transforms.Normalize(mean, std)]
        transform_train = transforms.Compose(lists)
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        train_dataset = datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
    
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_test),
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)
        CLASS = 100
    elif args.dataset == "svhn":
        mean = [0.4377, 0.4438, 0.4728]
        std  = [0.1980, 0.2010, 0.1970]

        lists = [transforms.RandomHorizontalFlip(),
                 transforms.RandomCrop(32, padding=4),
                 transforms.ToTensor(),
                 transforms.Normalize(mean, std)]
        transform_train = transforms.Compose(lists)
        transform_test  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        train_dataset = datasets.SVHN(root='../data', split='train', download=True, transform=transform_train)
        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(0.5 * num_train))

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            num_workers=4, pin_memory=use_gpu)

        train_dataprovider = DataIterator(train_loader)

        val_loader = torch.utils.data.DataLoader(
            datasets.SVHN(root='../data', split='train', download=True, transform=transform_test),
            batch_size=250, shuffle=False, sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            num_workers=4, pin_memory=use_gpu
        )

        val_dataprovider = DataIterator(val_loader)
        CLASS = 10

        
    print('load data successfully')
    
    model = Network(args.init_channels, CLASS, args.stacks,
                    eval(args.search_space)).cuda()

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    
    criterion_smooth = CrossEntropyLabelSmooth(CLASS, 0.1)
    
    if use_gpu:
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iters)
    model = model.to(device)
    
    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider
    
    args.evo_controller = evolutionary(args.max_population, args.select_number,
                                       args.mutation_len, args.mutation_number,
                                       args.p_edgewise, args.p_opwise)

    path = './record_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
        args.stacks, args.init_channels, args.total_iters, args.warmup_iters,
        args.max_population, args.select_number, args.mutation_len,
        args.mutation_number, args.val_interval, args.val_times,
        args.p_edgewise, args.p_opwise, args.evo_momentum, args.rand_seed,
        args.search_space, args.dataset)

    logging.info(path)

    while all_iters < args.total_iters:

        if all_iters > 1 and all_iters % args.val_interval == 0:
            results = []
            for structure_father in args.evo_controller.group:
                results.append([structure_father.structure,
                                structure_father.loss,
                                structure_father.count])
            if not os.path.exists(path):
                os.mkdir(path)

            with open(path + '/%06d-ep.txt' % all_iters, 'w') as tt:
                json.dump(results, tt)

            # Selection only starts once the warmup phase is over.
            if all_iters >= args.warmup_iters:
                args.evo_controller.select()
            else:
                print("warmup")

        all_iters = train(model, device, args, val_interval=args.val_interval,
                          bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
        
    # Dump the final population once the search budget is exhausted.
    results = []
    for structure_father in args.evo_controller.group:
        results.append([structure_father.structure,
                        structure_father.loss,
                        structure_father.count])
    with open(path + '/%06d-ep.txt' % all_iters, 'w') as tt:
        json.dump(results, tt)

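DataIterator (used above for train_dataprovider and val_dataprovider) is not defined in this snippet; it is assumed to be a thin wrapper that yields batches indefinitely, along the lines of:

class DataIterator:
    """Endless iterator over a DataLoader: restarts once the loader is exhausted."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def next(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self.iterator = iter(self.dataloader)
            return next(self.iterator)
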
class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        torch.cuda.set_device(self.args.gpu)
        self.device = torch.device("cuda")
        self.rank = 0
        self.seed = self.args.seed
        self.world_size = 1

        if self.args.fix_cudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter('./runs/' +
                                                     generate_date + '/' +
                                                     self.args.save_log)
        else:
            self.logger = None

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()
        if self.args.resume:
            self.reload_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        #iniatilize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size, self.args.steps,
                             self.args.multiplier)
        self.model.to(self.device)
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def reload_model(self):
        self.model.load_state_dict(torch.load(self.args.resume_path +
                                              '/weights.pt'),
                                   strict=True)

    def init_optimizer(self):

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         self.args.learning_rate,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)

        self.arch_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=self.args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=self.args.arch_weight_decay)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            def worker_init_fn(worker_id):
                # DataLoader passes the worker id; seed every worker for
                # reproducibility.
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
        else:
            worker_init_fn = None

        num_train = len(train_data)
        indices = list(range(num_train))

        self.train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)

        self.valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=2,
            worker_init_fn=worker_init_fn)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)
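        # Temp_Scheduler is project-specific (defined in utils); per the
        # comment above it is assumed to anneal the softmax temperature
        # linearly from its initial value down to self.args.temp_min over
        # the training epochs.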

        for epoch in range(self.args.epochs):
            if self.args.child_reward_stat:
                self.update_theta = False
                self.update_alpha = False

            if self.args.current_reward:
                self.model.normal_reward_mean = torch.zeros_like(
                    self.model.normal_reward_mean)
                self.model.reduce_reward_mean = torch.zeros_like(
                    self.model.reduce_reward_mean)
                self.model.count = 0

            if epoch < self.args.resume_epoch:
                continue
            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(F.softmax(self.model.normal_log_alpha, dim=-1))
                logging.info(F.softmax(self.model.reduce_log_alpha, dim=-1))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = remark + '_genotype_edge_all_epoch' + str(epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

                if self.args.dsnas and not self.args.child_reward_stat:
                    if self.args.current_reward:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        logging.info('count')
                        logging.info(self.model.count)
                    else:
                        logging.info('reward mean stat')
                        logging.info(self.model.normal_reward_mean)
                        logging.info(self.model.reduce_reward_mean)
                        if self.model.normal_reward_mean.size(0) > 1:
                            logging.info('reward mean total stat')
                            logging.info(self.model.normal_reward_mean.sum(0))
                            logging.info(self.model.reduce_reward_mean.sum(0))

                if self.args.child_reward_stat:
                    logging.info('reward mean stat')
                    logging.info(self.model.normal_reward_mean.sum(0))
                    logging.info(self.model.reduce_reward_mean.sum(0))
                    logging.info('reward var stat')
                    logging.info(
                        self.model.normal_reward_mean_square.sum(0) -
                        self.model.normal_reward_mean.sum(0)**2)
                    logging.info(
                        self.model.reduce_reward_mean_square.sum(0) -
                        self.model.reduce_reward_mean.sum(0)**2)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

        train_correct_count = 0
        train_correct_cost = 0
        train_correct_entropy = 0
        train_correct_loss = 0
        train_wrong_count = 0
        train_wrong_cost = 0
        train_wrong_entropy = 0
        train_wrong_loss = 0

        count = 0
        for step, (input, target) in enumerate(self.train_queue):

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux = self.model(input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha = self.model(
                    input,
                    target,
                    self.criterion,
                    update_theta=self.update_theta,
                    update_alpha=self.update_alpha)

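            # Per-sample bookkeeping: accumulate entropy, cost and loss
            # statistics separately for correctly and incorrectly classified
            # examples (reported in the end-of-epoch logging below).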
            for i in range(logits.size(0)):
                index = logits[i].topk(5, 0, True, True)[1]
                if index[0].item() == target[i].item():
                    train_correct_cost += (
                        -logits[i, target[i].item()] +
                        (F.softmax(logits[i], dim=-1) * logits[i]).sum())
                    train_correct_count += 1
                    discrete_prob = F.softmax(logits[i], dim=-1)
                    train_correct_entropy += -(
                        discrete_prob * torch.log(discrete_prob)).sum(-1)
                    train_correct_loss += -torch.log(discrete_prob)[
                        target[i].item()]
                else:
                    train_wrong_cost += (
                        -logits[i, target[i].item()] +
                        (F.softmax(logits[i], dim=-1) * logits[i]).sum())
                    train_wrong_count += 1
                    discrete_prob = F.softmax(logits[i], dim=-1)
                    train_wrong_entropy += -(discrete_prob *
                                             torch.log(discrete_prob)).sum(-1)
                    train_wrong_loss += -torch.log(discrete_prob)[
                        target[i].item()]

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce

            if self.args.snas or self.args.dsnas:
                loss = error_loss.clone()

            #self.update_lr()

            # logging gradient
            count += 1
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas and (not self.args.random_sample
                                   and not self.args.dsnas):
                loss.backward()

            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.grad_clip)
            arch_grad_norm = nn.utils.clip_grad_norm_(
                self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()

            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        logging.info('correct loss ')
        logging.info((train_correct_loss / train_correct_count).item())
        logging.info('correct entropy ')
        logging.info((train_correct_entropy / train_correct_count).item())
        logging.info('correct cost ')
        logging.info((train_correct_cost / train_correct_count).item())
        logging.info('correct count ')
        logging.info(train_correct_count)

        logging.info('wrong loss ')
        logging.info((train_wrong_loss / train_wrong_count).item())
        logging.info('wrong entropy ')
        logging.info((train_wrong_entropy / train_wrong_count).item())
        logging.info('wrong cost ')
        logging.info((train_wrong_cost / train_wrong_count).item())
        logging.info('wrong count ')
        logging.info(train_wrong_count)

        logging.info('total loss ')
        logging.info(((train_correct_loss + train_wrong_loss) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total entropy ')
        logging.info(((train_correct_entropy + train_wrong_entropy) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total cost ')
        logging.info(((train_correct_cost + train_wrong_cost) /
                      (train_correct_count + train_wrong_count)).item())
        logging.info('total count ')
        logging.info(train_correct_count + train_wrong_count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux = self.model(input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
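
A minimal sketch of how the class above is presumably driven (args and module-level globals such as generate_date, log_format and remark are assumed to come from the rest of the original script):

if __name__ == '__main__':
    nas = neural_architecture_search(args)
    nas.main()
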
Example #7
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(args, args.init_channels, CIFAR_CLASSES, args.layers,
                    criterion, args.steps, args.multiplier)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if args.resume:
        model.load_state_dict(torch.load(args.resume_path))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    for epoch in range(args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        #logging.info('genotype_edge_all = %s', genotype_edge_all)

        logging.info(model.alphas_normal)
        logging.info(model.alphas_reduce)
        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))
        logging.info('genotype = %s', genotype)

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        if args.cal_stat:
            logging.info('normal reward mean')
            logging.info(model.normal_reward_mean)
            logging.info('normal reward variance')
            logging.info(-model.normal_reward_mean**2 +
                         model.normal_reward_mean_square)
            logging.info('reduce reward mean')
            logging.info(model.reduce_reward_mean)
            logging.info('reduce reward variance')
            logging.info(-model.reduce_reward_mean**2 +
                         model.reduce_reward_mean_square)
            logging.info('normal reward total mean')
            logging.info(model.normal_reward_mean.sum(0))

        utils.save(model, os.path.join(args.save, 'weights.pt'))
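
train() and infer() are not shown in this snippet. In DARTS-style search code the train loop usually performs an architecture update on a held-out batch before each weight update; a rough sketch of one such step, under that assumption (names and the exact architect.step signature may differ in the original code):

def train_step(input_train, target_train, input_valid, target_valid,
               model, architect, criterion, optimizer, lr, grad_clip=5.0):
    # Architecture update on validation data (bi-level optimization).
    architect.step(input_train, target_train, input_valid, target_valid,
                   lr, optimizer, unrolled=False)

    # Weight update on training data.
    optimizer.zero_grad()
    logits = model(input_train)  # may return (logits, logits_aux) in some variants
    loss = criterion(logits, target_train)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    return loss
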
Example #8
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    start_epoch = 0

    if os.path.isfile(os.path.join(args.save, 'checkpoint.pt')):
        checkpoint = torch.load(os.path.join(args.save, 'checkpoint.pt'))
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        accuracies = checkpoint['accuracies']
        print(
            'Load checkpoint from {:} with start-epoch = {:}, acc: {}'.format(
                args.save, start_epoch, accuracies))

    architect = Architect(model, args)

    for epoch in range(start_epoch, args.epochs):
        scheduler.step()
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        utils.save(epoch, args, model, optimizer, scheduler, valid_acc,
                   os.path.join(args.save, 'checkpoint.pt'))
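
For the resume block above to work, utils.save presumably writes a dict containing the same keys that are read back; a compatible sketch (the real utils.save signature is not shown here):

def save_checkpoint(epoch, args, model, optimizer, scheduler, valid_acc, path):
    # Keys mirror what the resume block reads: epoch, state_dict, optimizer,
    # scheduler, and the recorded accuracy.
    torch.save({
        'epoch': epoch + 1,  # resume from the next epoch
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'accuracies': valid_acc,
    }, path)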