def get_cifar_tuned_model(load_weights=True):
    network = NetworkCIFAR(40, CIFAR_CLASSES, 20, True, 0.4, CIFAR_TUNED)
    if load_weights:
        device = torch.device('cpu')
        state_dict = torch.load('weights/cifar_tuned.pt', map_location=device)
        network.load_state_dict(state_dict)
    return network
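The pattern above, mapping a GPU-trained checkpoint onto the CPU before load_state_dict, works for any nn.Module. A minimal self-contained sketch of the same idea, using a stand-in model and a hypothetical checkpoint path rather than NetworkCIFAR:

import torch
import torch.nn as nn

def load_cpu_checkpoint(model: nn.Module, path: str) -> nn.Module:
    # map_location='cpu' remaps CUDA tensors stored in the checkpoint so the
    # load also works on machines without a GPU
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(), nn.Linear(8, 10))
# toy = load_cpu_checkpoint(toy, 'weights/toy.pt')  # hypothetical path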
Example 2
  def random_generate(self):
    """Randomly generate candidate architectures."""
    num_skip_connect = SearchControllerConf['random_search']['num_identity']
    num_arch = SearchControllerConf['random_search']['num_arch']
    flops_threshold = SearchControllerConf['random_search']['flops_threshold']

    # k = 2 + 3 + 4 + 5 = 14 edges for a 4-step cell
    k = sum(1 for i in range(self._steps) for n in range(2 + i))
    num_ops = len(PRIMITIVES)

    self.random_arch_list = []
    for ai in range(num_arch):
      seed = random.randint(0, 1000)
      torch.manual_seed(seed)
      while True:
        self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=False)
        self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=False)
        arch = self.genotype()
        # accept the sample only if it uses exactly num_skip_connect skip connections
        op_names, indices = zip(*arch.normal)
        cnt = 0
        for name, index in zip(op_names, indices):
          if name == 'skip_connect':
            cnt += 1
        if cnt == num_skip_connect:
          # then check the FLOPs threshold
          model = NetworkCIFAR(36, 10, 20, True, arch, False)
          flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
          if flops / 1e6 >= flops_threshold:
            self.random_arch_list += [('arch_' + str(ai), arch)]
            break
          else:
            continue

    return self.random_arch_list
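The FLOPs gate in random_generate appears to use the thop profiler (profile and clever_format match thop's signatures); assuming that, a smaller self-contained sketch of the same accept/reject check on a toy model:

import torch
import torch.nn as nn
from thop import profile, clever_format  # assumption: the `profile` above comes from thop

def passes_flops_threshold(model: nn.Module, flops_threshold_m: float) -> bool:
    # thop reports multiply-accumulate counts; the snippet above treats them as FLOPs
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
    print('model cost:', clever_format([flops, params], "%.3f"))
    return flops / 1e6 >= flops_threshold_m

toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
print(passes_flops_threshold(toy, flops_threshold_m=0.1))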
Example 3
def process_logs(args) -> DataFrame:
    data = []
    for log in args.train:
        row = []
        try:
            # evaluation stage metrics
            lines = str(log.readlines())
            match = re.search(r"arch='(?P<name>.*?)'", lines)
            name = match.group("name")
            row.append(name)
            # l2_loss_2e01 -> 2e-01
            weight_value = float(name.split("_")[2].replace("e", "e-"))
            row.append(weight_value)
            match = re.search(r"param size.*?(?P<value>\d*\.\d+)MB", lines)
            param_size = float(match.group("value"))
            row.append(param_size)
            for metric in [
                    TRAIN_LOSS, TRAIN_ACC, VALID_LOSS, VALID_ACC, TEST_LOSS,
                    TEST_ACC
            ]:
                value = float(
                    re.findall(rf'{metric}(?:uracy)? (?P<value>\d*\.\d+)',
                               lines)[-1])
                row.append(value)
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")
            while len(row) < 9:
                row.append(None)

        try:
            # search stage metrics
            genotype = genotypes.__dict__[name]
            genotype_str = str(genotype)
            match = False
            for s_log in args.search:
                s_lines = str(s_log.readlines())
                s_log.seek(0, 0)
                # ((?!\\n).)* matches anything except an escaped newline
                match = re.search(
                    r"stats = (?P<stats>{((?!\\n).)*" +
                    re.escape(genotype_str) + r".*?})\\n\",", s_lines)
                if match:
                    stats = eval(match.group("stats"))
                    # L2 loss case
                    if list(stats.get(L1_LOSS).keys())[0][0] == -1:
                        LOSS = L2_LOSS
                    # L1 loss case
                    elif list(stats.get(L2_LOSS).keys())[0][0] == -1:
                        LOSS = L1_LOSS
                    else:
                        raise Exception("L1 and L2 loss have w = -1")
                    values = list(stats.get(LOSS).values())[0]
                    search_criterion_loss = values[CRITERION_LOSS]
                    search_reg_loss = values[REG_LOSS]
                    row.append(search_criterion_loss)
                    row.append(search_reg_loss)
                    search_acc = values[VALID_ACC]
                    row.append(search_acc)
                    break
            if not match:
                raise Exception(f"Didn't find {name} on eval logs")
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")
            while len(row) < 12:
                row.append(None)

        try:
            # model profiling
            genotype = genotypes.__dict__[name]
            match = re.search(r"init_channels=(?P<value>\d+)", lines)
            init_channels = int(match.group("value"))
            match = re.search(r"layers=(?P<value>\d+)", lines)
            layers = int(match.group("value"))
            match = re.search(r"drop_path_prob=(?P<value>\d+\.\d+)", lines)
            drop_path_prob = float(match.group("value"))
            match = re.search(r"auxiliary=(?P<value>\w+)", lines)
            auxiliary = bool(match.group("value"))
            model = NetworkCIFAR(init_channels, 10, layers, auxiliary,
                                 genotype)
            model.cuda()
            model.drop_path_prob = drop_path_prob
            parameters, net_flops, total_time_gpu, total_time_cpu = model_profiling(
                model, name)
            row.append(parameters)
            row.append(net_flops)
            row.append(total_time_gpu)
            row.append(total_time_cpu)
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")

        if len(row) > 0:
            data.append(row)
    df = pd.DataFrame(data,
                      columns=[
                          MODEL_NAME, WEIGHT, PARAMETERS_DARTS, TRAIN_LOSS,
                          TRAIN_ACC, VALID_LOSS, VALID_ACC, TEST_LOSS,
                          TEST_ACC, SEARCH_CRIT_LOSS, SEARCH_REG_LOSS,
                          SEARCH_ACC, PARAMETERS_OFA, FLOPS, LATENCY_GPU,
                          LATENCY_CPU
                      ])
    df.set_index(keys=MODEL_NAME, inplace=True)
    df.sort_values(by=WEIGHT, inplace=True, ascending=False)
    pd.set_option("display.max_rows", None, "display.max_columns", None,
                  "display.width", None)
    print(df)
    df.to_csv(args.output)
    return df
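The log parsing above is just named-group regular expressions run over the raw text of each file; a minimal sketch of the same idea on a fabricated log string (the metric names here are placeholders, not the constants used above):

import re

log_text = "arch='l2_loss_2e01' param size = 3.1MB\\ntrain_acc 97.4\\nvalid_acc 96.1"

name = re.search(r"arch='(?P<name>.*?)'", log_text).group("name")
# l2_loss_2e01 -> 2e-01, as in the snippet above
weight = float(name.split("_")[2].replace("e", "e-"))
param_size = float(re.search(r"param size.*?(?P<value>\d*\.\d+)MB", log_text).group("value"))
# take the last occurrence of each metric, mirroring re.findall(...)[-1] above
valid_acc = float(re.findall(r"valid_acc (?P<value>\d*\.\d+)", log_text)[-1])
print(name, weight, param_size, valid_acc)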
Example 4
def train_arch(stage, step, valid_queue, model, optimizer_a, cur_switches_normal, cur_switches_reduce ):
    global best_prec1
    global best_normal_indices
    global best_reduce_indices
    # for step in range(100):
    # pull the next validation batch, recreating the iterator when it is missing or exhausted
    try:
        input_search, target_search = next(valid_queue_iter)
    except (NameError, StopIteration):
        valid_queue_iter = iter(valid_queue)
        input_search, target_search = next(valid_queue_iter)
    input_search = input_search.cuda()
    target_search = target_search.cuda(non_blocking=True)
    normal_grad_buffer = []
    reduce_grad_buffer = []
    reward_buffer = []
    params_buffer = []
    flops_list = []
    params_list = []
    # cifar_mu = np.ones((3, 32, 32))
    # cifar_mu[0, :, :] = 0.4914
    # cifar_mu[1, :, :] = 0.4822
    # cifar_mu[2, :, :] = 0.4465

# (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # cifar_std = np.ones((3, 32, 32))
    # cifar_std[0, :, :] = 0.2471
    # cifar_std[1, :, :] = 0.2435
    # cifar_std[2, :, :] = 0.2616
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # classifier = PyTorchClassifier(
    #     model=model,
    #     clip_values=(0.0, 1.0),
    #     preprocessing=(cifar_mu, cifar_std),
    #     loss=criterion,
    #     optimizer=optimizer,
    #     input_shape=(3, 32, 32),
    #     nb_classes=10,
    # )

    for batch_idx in range(model.module.rl_batch_size): # sample several sub-networks and evaluate each
        # sample the submodel
        # if stage == 1:
        #     print("ok")
        normal_indices, reduce_indices, genotype = get_cur_model(model, cur_switches_normal, cur_switches_reduce)
        # return 0.0, 0.0
        # attack = FastGradientMethod(estimator=model, eps=0.2)
        # x_test_adv = attack.generate(x=x_test)
        # res = clever_u(classifier,valid_queue.dataset.data[-1].transpose(2,0,1) , 2, 2, R_LI, norm=np.inf, pool_factor=3)
        # print(res)
        # validate the sub-model
        with torch.no_grad():
            logits= model(input_search)
            prec1, _ = utils.accuracy(logits, target_search, topk=(1,5))
        sub_model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, 20, False, genotype)
        sub_model.drop_path_prob = 0
        # para0 = utils.count_parameters_in_MB(sub_model)
        input = torch.randn(1,3,32,32)
        flops, params = profile(sub_model, inputs = (input,), )
        flops_s, params_s = clever_format([flops, params], "%.3f")
        flops, params = flops/1e9, params/1e6
        params_buffer.append(params)
        flops_list.append(flops_s)
        params_list.append(params_s)


            # prec1 = np.random.rand()
        if model.module._arch_parameters[0].grad is not None:
            model.module._arch_parameters[0].grad.data.zero_()
        if model.module._arch_parameters[1].grad is not None:
            model.module._arch_parameters[1].grad.data.zero_()
        obj_term = 0
        for i in range(14):
            obj_term = obj_term + model.module.normal_log_prob[i]
            obj_term = obj_term + model.module.reduce_log_prob[i]
        loss_term = -obj_term
        # backward
        loss_term.backward()
        # take out gradient dict
        normal_grad_buffer.append(model.module._arch_parameters[0].grad.data.clone())
        reduce_grad_buffer.append(model.module._arch_parameters[1].grad.data.clone())
        reward = calculate_reward(stage, prec1, params)
        reward_buffer.append(reward)
        # record the indices of the best-reward architecture
        if prec1 > best_prec1:
            best_prec1 = prec1
            best_normal_indices = normal_indices
            best_reduce_indices = reduce_indices
        # else:
        #     best_normal_indices = []
        #     best_reduce_indices = []
    logging.info(flops_list)
    logging.info(params_list)
    logging.info(normal_indices.detach().cpu().numpy().squeeze())
    logging.info(reduce_indices.detach().cpu().numpy().squeeze())
    logging.info(genotype)
    avg_reward = sum(reward_buffer) / model.module.rl_batch_size
    avg_params = sum(params_buffer) / model.module.rl_batch_size
    if model.module.baseline == 0:
        model.module.baseline = avg_reward
    else:
        model.module.baseline += model.module.baseline_decay_weight * (avg_reward - model.module.baseline) # hs
        # model.module.baseline = model.module.baseline_decay_weight * model.module.baseline + \
        #                         (1-model.module.baseline_decay_weight) * avg_reward

    model.module._arch_parameters[0].grad.data.zero_()
    model.module._arch_parameters[1].grad.data.zero_()
    for j in range(model.module.rl_batch_size):
        model.module._arch_parameters[0].grad.data += (reward_buffer[j] - model.module.baseline) * normal_grad_buffer[j]
        model.module._arch_parameters[1].grad.data += (reward_buffer[j] - model.module.baseline) * reduce_grad_buffer[j]
    model.module._arch_parameters[0].grad.data /= model.module.rl_batch_size
    model.module._arch_parameters[1].grad.data /= model.module.rl_batch_size
    # if step % 50 == 0:
    #     logging.info(model.module._arch_parameters[0].grad.data)
    #     logging.info(model.module._arch_parameters[0])
    # apply gradients
    nn.utils.clip_grad_norm_(model.module.arch_parameters(), args.grad_clip)
    optimizer_a.step()

    if step % args.report_freq == 0:
        #     logging.info(model.module._arch_parameters[0])
        # valid the argmax arch
        logging.info('REINFORCE [step %d]\t\tMean Reward %.4f\tBaseline %.4f\tBest Sampled Prec1 %.4f', step, avg_reward, model.module.baseline, best_prec1)
        max_normal_index, max_reduce_index = set_max_model(model, cur_switches_normal, cur_switches_reduce)
        logits= model(input_search)
        prec1, _ = utils.accuracy(logits, target_search, topk=(1,5))
        logging.info('REINFORCE [step %d]\t\tCurrent Max Architecture Reward %.4f\t\tAverage Params %.3f', step, prec1/100, avg_params)
        max_arch_reward_writer.add_scalar('max_arch_reward_{}'.format(stage), prec1, tb_index[stage])
        avg_params_writer.add_scalar('avg_params_{}'.format(stage), avg_params, tb_index[stage])
        logging.info(max_normal_index)
        logging.info(max_reduce_index)
        best_reward_arch_writer.add_scalar('best_prec1_arch_{}'.format(stage), best_prec1, tb_index[stage])

        logging.info(np.around(torch.Tensor(reward_buffer).numpy(),3))
        # logging.info(model.module.normal_probs)
        # logging.info(model.module.reduce_probs)
        logging.info(model.module.alphas_normal)
        logging.info(model.module.alphas_reduce)

        for i in range(14):
            normal_max_writer[i].add_scalar('normal_max_arch_{}'.format(stage), np.argmax(model.module.normal_probs.detach().cpu()[i].numpy()), tb_index[stage])

            normal_min_k = get_min_k(model.module.normal_probs.detach().cpu()[i].numpy(), normal_num_to_drop[stage])
            for j in range(normal_num_to_drop[stage]):
                normal_min_writer[i][j].add_scalar('normal_min_arch_{}_{}'.format(stage, j), normal_min_k[j], tb_index[stage])

            best_normal_writer[i].add_scalar('best_normal_index_{}'.format(stage), best_normal_indices[i].cpu().numpy(), tb_index[stage])

        for i in range(14):
            reduce_max_writer[i].add_scalar('reduce_max_arch_{}'.format(stage), np.argmax(model.module.reduce_probs.detach().cpu()[i].numpy()), tb_index[stage])

            reduce_min_k = get_min_k(model.module.reduce_probs.detach().cpu()[i].numpy(), reduce_num_to_drop[stage])
            for j in range(reduce_num_to_drop[stage]):
                reduce_min_writer[i][j].add_scalar('reduce_min_arch_{}_{}'.format(stage, j), reduce_min_k[j], tb_index[stage])

            best_reduce_writer[i].add_scalar('best_reduce_index_{}'.format(stage), best_reduce_indices[i].cpu().numpy(), tb_index[stage])

        best_prec1 = 0
        tb_index[stage]+=1

    model.module.restore_super_net()
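The policy-gradient update above is plain REINFORCE with an exponential-moving-average baseline over the architecture logits; a stripped-down sketch of that estimator on a tiny categorical policy (the names and reward here are illustrative, not taken from the snippet):

import torch

torch.manual_seed(0)
logits = torch.zeros(14, 8, requires_grad=True)    # 14 edges, 8 candidate ops
baseline, decay, rl_batch_size = 0.0, 0.9, 4

grad_buffer, reward_buffer = [], []
for _ in range(rl_batch_size):
    dist = torch.distributions.Categorical(logits=logits)
    sample = dist.sample()                          # one op index per edge
    log_prob = dist.log_prob(sample).sum()
    if logits.grad is not None:
        logits.grad.zero_()
    (-log_prob).backward()                          # gradient of -log pi(a)
    grad_buffer.append(logits.grad.clone())
    reward_buffer.append(float(sample.float().mean()))   # stand-in reward

avg_reward = sum(reward_buffer) / rl_batch_size
baseline = avg_reward if baseline == 0 else baseline + decay * (avg_reward - baseline)
policy_grad = sum((r - baseline) * g for r, g in zip(reward_buffer, grad_buffer)) / rl_batch_size
logits.data -= 0.1 * policy_grad                    # one REINFORCE ascent step on reward
print(policy_grad.norm())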
Example 5
def process_logs(args) -> DataFrame:
    # filter logs
    lines = str(args.log.readlines())
    match = re.search(r"Selected regularization(?P<reg>.*?)\\n", lines)
    reg = match.group("reg")
    if L1_LOSS in reg:
        loss = L1_LOSS
    elif L2_LOSS in reg:
        loss = L2_LOSS
    else:
        raise RuntimeError("Cant decode line Selected regularization")
    match = re.finditer(r"hist = (?P<hist>.*?)\\n", lines)
    hist_str = list(match)[-1].group("hist")
    hist = eval(hist_str)[loss]
    print("Removing non-optimal samples")
    # filter out dominated points
    filter_hist(hist)
    data = []
    for weight, result in hist.items():
        row = []
        name = create_genotype_name(weight, loss)
        try:
            row.append(name)
            weight_value = weight[0]
            row.append(weight_value)
            # {'train_acc': 25.035999994506835,
            # 'valid_acc': 20.171999999084473,
            # 'reg_loss': 16.01249122619629,
            # 'criterion_loss': 1.9922981262207031,
            # 'model_size': 1.81423,
            # 'genotype': Genotype(..)}
            for metric in [
                    SIZE, TRAIN_ACC, VALID_ACC, CRITERION_LOSS, REG_LOSS
            ]:
                row.append(result[metric])
        except Exception as e:
            print(f"Error '{e}' while processing file {args.log} w={weight}")
            while len(row) < 7:
                row.append(None)

        try:
            # model profiling
            genotype = result[GENOTYPE]
            # using default from train.py for CIFAR10
            model = NetworkCIFAR(36, 10, 20, False, genotype)
            model.cuda()
            model.drop_path_prob = 0.3
            parameters, net_flops, total_time_gpu, total_time_cpu = model_profiling(
                model, name)
            row.append(parameters)
            row.append(net_flops)
            row.append(total_time_gpu)
            row.append(total_time_cpu)
        except Exception as e:
            print(f"Error '{e}' while processing file {args.log} w={weight}")
            raise e

        if len(row) > 0:
            data.append(row)
    df = pd.DataFrame(data,
                      columns=[
                          MODEL_NAME, WEIGHT, "Params", SEARCH_TRAIN_ACC,
                          SEARCH_VAL_ACC, SEARCH_CRIT_LOSS, SEARCH_REG_LOSS,
                          "Parameters", FLOPS, "Latency GPU", LATENCY_CPU
                      ])
    df.set_index(keys=MODEL_NAME, inplace=True)
    df.sort_values(by=WEIGHT, inplace=True, ascending=False)
    pd.set_option("display.max_rows", None, "display.max_columns", None,
                  "display.width", None)
    print(df)
    df.to_csv(args.output)
    return df
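filter_hist above is described only as removing dominated points; under the assumption that "dominated" means worse-or-equal on every objective, a generic sketch of such a filter (maximizing accuracy, minimizing size) looks like this:

def pareto_filter(points):
    """Keep points not dominated by any other (higher acc better, lower size better).

    `points` is a list of (name, acc, size) tuples; this helper is illustrative,
    not the filter_hist used above.
    """
    kept = []
    for name, acc, size in points:
        dominated = any(
            (o_acc >= acc and o_size <= size) and (o_acc > acc or o_size < size)
            for o_name, o_acc, o_size in points
        )
        if not dominated:
            kept.append((name, acc, size))
    return kept

print(pareto_filter([('a', 96.0, 3.0), ('b', 95.0, 3.5), ('c', 96.5, 4.0)]))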
Example 6
def main():
  parser = argparse.ArgumentParser("Common Argument Parser")
  parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
  parser.add_argument('--dataset', type=str, default='cifar10', help='which dataset:\
                      cifar10, mnist, emnist, fashion, svhn, stl10, devanagari')
  parser.add_argument('--batch_size', type=int, default=64, help='batch size')
  parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
  parser.add_argument('--learning_rate_min', type=float, default=1e-8, help='min learning rate')
  parser.add_argument('--lr_power_annealing_exponent_order', type=float, default=2,
                      help='Cosine Power Annealing Schedule Base, larger numbers make '
                           'the exponential more dominant, smaller make cosine more dominant, '
                           '1 returns to standard cosine annealing.')
  parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
  parser.add_argument('--weight_decay', '--wd', dest='weight_decay', type=float, default=3e-4, help='weight decay')
  parser.add_argument('--partial', default=1/8, type=float, help='partially adaptive parameter p in Padam')
  parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
  parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
  parser.add_argument('--epochs', type=int, default=2000, help='num of training epochs')
  parser.add_argument('--start_epoch', default=1, type=int, metavar='N',
                      help='manual epoch number (useful for restarts)')
  parser.add_argument('--warmup_epochs', type=int, default=5, help='num of warmup training epochs')
  parser.add_argument('--warm_restarts', type=int, default=20, help='warm restarts of cosine annealing')
  parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
  parser.add_argument('--mid_channels', type=int, default=32, help='C_mid channels in choke SharpSepConv')
  parser.add_argument('--layers', type=int, default=20, help='total number of layers')
  parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
  parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
  parser.add_argument('--mixed_auxiliary', action='store_true', default=False, help='Learn weights for auxiliary networks during training. Overrides auxiliary flag')
  parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
  parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
  parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
  parser.add_argument('--autoaugment', action='store_true', default=False, help='use cifar10 autoaugment https://arxiv.org/abs/1805.09501')
  parser.add_argument('--random_eraser', action='store_true', default=False, help='use random eraser')
  parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
  parser.add_argument('--save', type=str, default='EXP', help='experiment name')
  parser.add_argument('--seed', type=int, default=0, help='random seed')
  parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
  parser.add_argument('--ops', type=str, default='OPS', help='which operations to use, options are OPS and DARTS_OPS')
  parser.add_argument('--primitives', type=str, default='PRIMITIVES',
                      help='which primitive layers to use inside a cell search space,'
                           ' options are PRIMITIVES, SHARPER_PRIMITIVES, and DARTS_PRIMITIVES')
  parser.add_argument('--optimizer', type=str, default='sgd', help='which optimizer to use, options are padam and sgd')
  parser.add_argument('--load', type=str, default='',  metavar='PATH', help='load weights at specified location')
  parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
  parser.add_argument('--flops', action='store_true', default=False, help='count flops and exit, aka floating point operations.')
  parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, metavar='PATH', default='',
                      help='evaluate model at specified path on training, test, and validation datasets')
  parser.add_argument('--multi_channel', action='store_true', default=False, help='perform multi channel search, a completely separate search space')
  parser.add_argument('--load_args', type=str, default='',  metavar='PATH',
                      help='load command line args from a json file, this will override '
                           'all currently set args except for --evaluate, and arguments '
                           'that did not exist when the json file was originally saved out.')
  parser.add_argument('--layers_of_cells', type=int, default=8, help='total number of cells in the whole network, default is 8 cells')
  parser.add_argument('--layers_in_cells', type=int, default=4,
                      help='Total number of nodes in each cell, aka number of steps,'
                           ' default is 4 nodes, which implies 8 ops')
  parser.add_argument('--weighting_algorithm', type=str, default='scalar',
                    help='which weighting algorithm to use, options are '
                         '"max_w" ((1. - max_w + w) * op) and "scalar" (w * op)')
  # TODO(ahundt) remove final path and switch back to genotype
  parser.add_argument('--load_genotype', type=str, default=None, help='Name of genotype to be used')
  parser.add_argument('--simple_path', default=True, action='store_false', help='Final model is a simple path (MultiChannelNetworkModel)')
  args = parser.parse_args()

  args = utils.initialize_files_and_args(args)

  logger = utils.logging_setup(args.log_file_path)

  if not torch.cuda.is_available():
    logger.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  logger.info('gpu device = %d' % args.gpu)
  logger.info("args = %s", args)

  DATASET_CLASSES = dataset.class_dict[args.dataset]
  DATASET_CHANNELS = dataset.inp_channel_dict[args.dataset]
  DATASET_MEAN = dataset.mean_dict[args.dataset]
  DATASET_STD = dataset.std_dict[args.dataset]
  logger.info('output channels: ' + str(DATASET_CLASSES))

  # # load the correct ops dictionary
  op_dict_to_load = "operations.%s" % args.ops
  logger.info('loading op dict: ' + str(op_dict_to_load))
  op_dict = eval(op_dict_to_load)

  # load the correct primitives list
  primitives_to_load = "genotypes.%s" % args.primitives
  logger.info('loading primitives:' + primitives_to_load)
  primitives = eval(primitives_to_load)
  logger.info('primitives: ' + str(primitives))

  genotype = eval("genotypes.%s" % args.arch)
  # create the neural network

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  if args.multi_channel:
    final_path = None
    if args.load_genotype is not None:
      genotype = getattr(genotypes, args.load_genotype)
      print(genotype)
      if type(genotype[0]) is str:
        logger.info('Path :%s', genotype)
    # TODO(ahundt) remove final path and switch back to genotype
    cnn_model = MultiChannelNetwork(
      args.init_channels, DATASET_CLASSES, layers=args.layers_of_cells, criterion=criterion, steps=args.layers_in_cells,
      weighting_algorithm=args.weighting_algorithm, genotype=genotype)
    flops_shape = [1, 3, 32, 32]
  elif args.dataset == 'imagenet':
      cnn_model = NetworkImageNet(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
      flops_shape = [1, 3, 224, 224]
  else:
      cnn_model = NetworkCIFAR(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
      flops_shape = [1, 3, 32, 32]
  cnn_model = cnn_model.cuda()

  logger.info("param size = %fMB", utils.count_parameters_in_MB(cnn_model))
  if args.flops:
    logger.info('flops_shape = ' + str(flops_shape))
    logger.info("flops = " + utils.count_model_flops(cnn_model, data_shape=flops_shape))
    return

  optimizer = torch.optim.SGD(
      cnn_model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay
      )

  # Get preprocessing functions (i.e. transforms) to apply on data
  train_transform, valid_transform = utils.get_data_transforms(args)
  if args.evaluate:
    # evaluate the train dataset without augmentation
    train_transform = valid_transform

  # Get the training queue, use full training and test set
  train_queue, valid_queue = dataset.get_training_queues(
    args.dataset, train_transform, valid_transform, args.data, args.batch_size, train_proportion=1.0, search_architecture=False)

  test_queue = None
  if args.dataset == 'cifar10':
    # evaluate best model weights on cifar 10.1
    # https://github.com/modestyachts/CIFAR-10.1
    test_data = cifar10_1.CIFAR10_1(root=args.data, download=True, transform=valid_transform)
    test_queue = torch.utils.data.DataLoader(
      test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)

  if args.evaluate:
    # evaluate the loaded model, print the result, and return
    logger.info("Evaluating inference with weights file: " + args.load)
    eval_stats = evaluate(
      args, cnn_model, criterion, args.load,
      train_queue=train_queue, valid_queue=valid_queue, test_queue=test_queue)
    with open(args.stats_file, 'w') as f:
      arg_dict = vars(args)
      arg_dict.update(eval_stats)
      json.dump(arg_dict, f)
    logger.info("flops = " + utils.count_model_flops(cnn_model))
    logger.info(utils.dict_to_log_string(eval_stats))
    logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' + str(args.save))
    return

  lr_schedule = cosine_power_annealing(
    epochs=args.epochs, max_lr=args.learning_rate, min_lr=args.learning_rate_min,
    warmup_epochs=args.warmup_epochs, exponent_order=args.lr_power_annealing_exponent_order)
  epochs = np.arange(args.epochs) + args.start_epoch
  # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
  epoch_stats = []

  stats_csv = args.epoch_stats_file
  stats_csv = stats_csv.replace('.json', '.csv')
  with tqdm(epochs, dynamic_ncols=True) as prog_epoch:
    best_valid_acc = 0.0
    best_epoch = 0
    best_stats = {}
    stats = {}
    epoch_stats = []
    weights_file = os.path.join(args.save, 'weights.pt')
    for epoch, learning_rate in zip(prog_epoch, lr_schedule):
      # update the drop_path_prob augmentation
      cnn_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
      # update the learning rate
      for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
      # scheduler.get_lr()[0]

      train_acc, train_obj = train(args, train_queue, cnn_model, criterion, optimizer)

      val_stats = infer(args, valid_queue, cnn_model, criterion)
      stats.update(val_stats)
      stats['train_acc'] = train_acc
      stats['train_loss'] = train_obj
      stats['lr'] = learning_rate
      stats['epoch'] = epoch

      if stats['valid_acc'] > best_valid_acc:
        # new best epoch, save weights
        utils.save(cnn_model, weights_file)
        best_epoch = epoch
        best_stats.update(copy.deepcopy(stats))
        best_valid_acc = stats['valid_acc']
        best_train_loss = train_obj
        best_train_acc = train_acc
      # else:
      #   # not best epoch, load best weights
      #   utils.load(cnn_model, weights_file)
      logger.info('epoch, %d, train_acc, %f, valid_acc, %f, train_loss, %f, valid_loss, %f, lr, %e, best_epoch, %d, best_valid_acc, %f, ' + utils.dict_to_log_string(stats),
                  epoch, train_acc, stats['valid_acc'], train_obj, stats['valid_loss'], learning_rate, best_epoch, best_valid_acc)
      stats['train_acc'] = train_acc
      stats['train_loss'] = train_obj
      epoch_stats += [copy.deepcopy(stats)]
      with open(args.epoch_stats_file, 'w') as f:
        json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
      utils.list_of_dicts_to_csv(stats_csv, epoch_stats)

    # get stats from best epoch including cifar10.1
    eval_stats = evaluate(args, cnn_model, criterion, weights_file, train_queue, valid_queue, test_queue)
    with open(args.stats_file, 'w') as f:
      arg_dict = vars(args)
      arg_dict.update(eval_stats)
      json.dump(arg_dict, f, cls=utils.NumpyEncoder)
    with open(args.epoch_stats_file, 'w') as f:
      json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
    logger.info(utils.dict_to_log_string(eval_stats))
    logger.info('Training of Final Model Complete! Save dir: ' + str(args.save))
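Instead of a torch scheduler, the loop above writes a precomputed learning rate into the optimizer's param groups each epoch. The same pattern in isolation, with a simple linear decay standing in for cosine_power_annealing:

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

epochs = 5
lr_schedule = np.linspace(0.1, 0.001, epochs)       # stand-in for cosine_power_annealing
for epoch, learning_rate in zip(range(epochs), lr_schedule):
    # setting param_group['lr'] directly replaces scheduler.step()
    for param_group in optimizer.param_groups:
        param_group['lr'] = float(learning_rate)
    print(epoch, optimizer.param_groups[0]['lr'])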
Example 7
def main():
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True

  if not args.random:
    # We would always get the same random architecture if we set the random
    # seed here. We'll set it after finding a random genotype.
    torch.manual_seed(args.seed)

  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  train_data, OUTPUT_DIM, IN_CHANNELS, is_regression = load_dataset(args, train=True)

  criterion = nn.CrossEntropyLoss() if not is_regression else nn.MSELoss()

  if args.random:
    model_tmp = Network(C=args.init_channels, num_classes=OUTPUT_DIM, layers=args.layers, 
                        primitives_name=primitives_name, criterion=criterion, num_channels=IN_CHANNELS)
    genotype = model_tmp.genotype()  # Random

    # We can now set the random seed.
    torch.manual_seed(args.seed)
  # If the architecture was the default DATASET, look up the architecture corresponding to this dataset
  elif args.arch == 'DATASET':
    genotype = GENOTYPE_TBL[dataset]
    print(f'using genotype for {dataset}')
  else:
    try:
      genotype = eval("genotypes.%s" % args.arch)
    except (AttributeError, SyntaxError):
      genotype = genotypes.load_genotype_from_file(args.arch)

  genotypes.save_genotype_to_file(genotype, os.path.join(args.save, "genotype.arch"))
  # Set the inference network; default is NetworkCifar10; supported alternatives NetworkGalaxyZoo
  if dataset == 'GalaxyZoo' and args.gz_dtree:
    model = NetworkGalaxyZoo(C=args.init_channels, num_classes=OUTPUT_DIM, layers=args.layers, genotype=genotype,
                             fc1_size=args.fc1_size, fc2_size=args.fc2_size, num_channels=IN_CHANNELS)
  else:
    model = NetworkCIFAR(args.init_channels, OUTPUT_DIM, args.layers, args.auxiliary, genotype, num_channels=IN_CHANNELS)
  model = model.cuda()

  logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

  criterion = criterion.cuda()

  # build optimizer based on optimizer input; one of SGD or Adam
  if args.optimizer == 'SGD':
    optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)
  elif args.optimizer == 'Adam':
    optimizer= torch.optim.Adam(
    params=model.parameters(),
    lr=args.learning_rate,
    betas=(0.90, 0.999),
    weight_decay=args.weight_decay)
  else:
    raise ValueError(f"Bad optimizer; got {args.optimizer}, must be one of 'SGD' or 'Adam'.")

  # Split training data into training and validation queues
  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)

  # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

  # history of training and validation loss; 2 columns for loss and accuracy / R2
  hist_trn = np.zeros((args.epochs, 2))
  hist_val = np.zeros((args.epochs, 2))
  metric_name = 'accuracy' if not is_regression else 'R2'

  for epoch in range(args.epochs):
    # scheduler.step()
    # logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
    logging.info('epoch %d lr %e', epoch, args.learning_rate)
    # model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
    model.drop_path_prob = args.drop_path_prob

    # training results
    train_acc, train_obj = train(train_queue, model, criterion, optimizer, is_regression=is_regression)
    logging.info(f'training loss; {metric_name}: {train_obj:e} {train_acc:f}')
    # save history to numpy arrays
    hist_trn[epoch] = [train_acc, train_obj]
    np.save(os.path.join(args.save, 'hist_trn'), hist_trn)

    # validation results
    valid_acc, valid_obj = infer(valid_queue, model, criterion, is_regression=is_regression)
    logging.info(f'validation loss; {metric_name}: {valid_obj:e} {valid_acc:f}')
    # save history to numpy arrays
    hist_val[epoch] = [valid_acc, valid_obj]
    np.save(os.path.join(args.save, 'hist_val'), hist_val)

    # save current model weights
    utils.save(model, os.path.join(args.save, 'weights.pt'))
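The history bookkeeping above writes one (metric, loss) row per epoch into preallocated arrays and re-saves them after every epoch, so a crash loses at most the current epoch. A minimal sketch of that pattern with stand-in metrics:

import os
import numpy as np

save_dir, epochs = '.', 3
hist_trn = np.zeros((epochs, 2))                    # columns: accuracy (or R2), loss
for epoch in range(epochs):
    train_acc, train_obj = 90.0 + epoch, 0.5 / (epoch + 1)   # stand-in metrics
    hist_trn[epoch] = [train_acc, train_obj]
    np.save(os.path.join(save_dir, 'hist_trn'), hist_trn)    # overwrite each epoch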
Example 8

                                                 genotype=None)
elif args.dataset == 'imagenet':
    cnn_model = NetworkImageNet(args.init_channels,
                                classes,
                                args.layers,
                                args.auxiliary,
                                genotype,
                                op_dict=op_dict,
                                C_mid=args.mid_channels)
    # workaround for graph generation limitations
    cnn_model.drop_path_prob = torch.zeros(1)
else:
    cnn_model = NetworkCIFAR(args.init_channels,
                             classes,
                             args.layers,
                             args.auxiliary,
                             genotype,
                             op_dict=op_dict,
                             C_mid=args.mid_channels)
    # workaround for graph generation limitations
    cnn_model.drop_path_prob = torch.zeros(1)

transforms = [
    hl.transforms.Fold('MaxPool3x3 > Conv1x1 > BatchNorm', 'ResizableMaxPool',
                       'ResizableMaxPool'),
    hl.transforms.Fold('MaxPool > Conv > BatchNorm', 'ResizableMaxPool',
                       'ResizableMaxPool'),
    hl.transforms.Fold('Relu > Conv > Conv > BatchNorm', 'ReluSepConvBn'),
    hl.transforms.Fold('ReluSepConvBn > ReluSepConvBn', 'SharpSepConv',
                       'SharpSepConv'),
    hl.transforms.Fold('Relu > Conv > BatchNorm', 'ReLUConvBN'),
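The truncated snippet above builds a list of hiddenlayer Fold transforms that collapse common op sequences into single graph nodes. Assuming the standard hiddenlayer API (hl.build_graph accepting a transforms keyword, Graph.save for rendering), the list would typically be closed and used roughly like this on a toy model:

import torch
import torch.nn as nn
import hiddenlayer as hl   # assumption: the `hl` above is the hiddenlayer package

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
transforms = [
    # same idea as the Fold rules above: collapse an op sequence into one node
    hl.transforms.Fold('Conv > BatchNorm > Relu', 'ConvBnRelu'),
]
graph = hl.build_graph(toy, torch.zeros([1, 3, 32, 32]), transforms=transforms)
graph.save('toy_graph', format='pdf')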
Example 9
def main():
    global best_top1, args, logger

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # commented because it is now set as an argparse param.
    # args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    # note the gpu is used for directory creation and log files
    # which is needed when run as multiple processes
    args = utils.initialize_files_and_args(args)
    logger = utils.logging_setup(args.log_file_path)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            logger.info(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # # load the correct ops dictionary
    op_dict_to_load = "operations.%s" % args.ops
    logger.info('loading op dict: ' + str(op_dict_to_load))
    op_dict = eval(op_dict_to_load)

    # load the correct primitives list
    primitives_to_load = "genotypes.%s" % args.primitives
    logger.info('loading primitives:' + primitives_to_load)
    primitives = eval(primitives_to_load)
    logger.info('primitives: ' + str(primitives))
    # create model
    genotype = eval("genotypes.%s" % args.arch)
    # get the number of output channels
    classes = dataset.class_dict[args.dataset]
    # create the neural network
    if args.dataset == 'imagenet':
        model = NetworkImageNet(args.init_channels,
                                classes,
                                args.layers,
                                args.auxiliary,
                                genotype,
                                op_dict=op_dict,
                                C_mid=args.mid_channels)
        flops_shape = [1, 3, 224, 224]
    else:
        model = NetworkCIFAR(args.init_channels,
                             classes,
                             args.layers,
                             args.auxiliary,
                             genotype,
                             op_dict=op_dict,
                             C_mid=args.mid_channels)
        flops_shape = [1, 3, 32, 32]
    model.drop_path_prob = 0.0
    # if args.pretrained:
    #     logger.info("=> using pre-trained model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](pretrained=True)
    # else:
    #     logger.info("=> creating model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch]()

    if args.flops:
        model = model.cuda()
        logger.info("param size = %fMB", utils.count_parameters_in_MB(model))
        logger.info("flops_shape = " + str(flops_shape))
        logger.info("flops = " +
                    utils.count_model_flops(model, data_shape=flops_shape))
        return

    if args.sync_bn:
        import apex
        logger.info("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Scale learning rate based on global batch size
    args.learning_rate = args.learning_rate * float(
        args.batch_size * args.world_size) / 256.
    init_lr = args.learning_rate / args.warmup_lr_divisor
    optimizer = torch.optim.SGD(model.parameters(),
                                init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # epoch_count = args.epochs - args.start_epoch
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epoch_count))
    # scheduler = warmup_scheduler.GradualWarmupScheduler(
    #     optimizer, args.warmup_lr_divisor, args.warmup_epochs, scheduler)

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        if args.evaluate:
            args.resume = args.evaluate
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                if 'best_top1' in checkpoint:
                    best_top1 = checkpoint['best_top1']
                model.load_state_dict(checkpoint['state_dict'])
                # An FP16_Optimizer instance's state dict internally stashes the master params.
                optimizer.load_state_dict(checkpoint['optimizer'])
                # TODO(ahundt) make sure scheduler loading isn't broken
                if 'lr_scheduler' in checkpoint:
                    scheduler.load_state_dict(checkpoint['lr_scheduler'])
                elif 'lr_schedule' in checkpoint:
                    lr_schedule = checkpoint['lr_schedule']
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        resume()

    # # Data loading code
    # traindir = os.path.join(args.data, 'train')
    # valdir = os.path.join(args.data, 'val')

    # if(args.arch == "inception_v3"):
    #     crop_size = 299
    #     val_size = 320 # I chose this value arbitrarily, we can adjust.
    # else:
    #     crop_size = 224
    #     val_size = 256

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(crop_size),
    #         transforms.RandomHorizontalFlip(),
    #         autoaugment.ImageNetPolicy(),
    #         # transforms.ToTensor(),  # Too slow, moved to data_prefetcher()
    #         # normalize,
    #     ]))
    # val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(val_size),
    #         transforms.CenterCrop(crop_size)
    #     ]))

    # train_sampler = None
    # val_sampler = None
    # if args.distributed:
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    #     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    #     num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True,
    #     sampler=val_sampler,
    #     collate_fn=fast_collate)

    # Get preprocessing functions (i.e. transforms) to apply on data
    # normalize_as_tensor = False because we normalize and convert to a
    # tensor in our custom prefetching function, rather than as part of
    # the transform preprocessing list.
    train_transform, valid_transform = utils.get_data_transforms(
        args, normalize_as_tensor=False)
    # Get the training queue, select training and validation from training set
    train_loader, val_loader = dataset.get_training_queues(
        args.dataset,
        train_transform,
        valid_transform,
        args.data,
        args.batch_size,
        train_proportion=1.0,
        collate_fn=fast_collate,
        distributed=args.distributed,
        num_workers=args.workers)

    if args.evaluate:
        if args.dataset == 'cifar10':
            # evaluate best model weights on cifar 10.1
            # https://github.com/modestyachts/CIFAR-10.1
            train_transform, valid_transform = utils.get_data_transforms(args)
            # Get the training queue, select training and validation from training set
            # Get the training queue, use full training and test set
            train_queue, valid_queue = dataset.get_training_queues(
                args.dataset,
                train_transform,
                valid_transform,
                args.data,
                args.batch_size,
                train_proportion=1.0,
                search_architecture=False)
            test_data = cifar10_1.CIFAR10_1(root=args.data,
                                            download=True,
                                            transform=valid_transform)
            test_queue = torch.utils.data.DataLoader(
                test_data,
                batch_size=args.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=args.workers)
            eval_stats = evaluate(args,
                                  model,
                                  criterion,
                                  train_queue=train_queue,
                                  valid_queue=valid_queue,
                                  test_queue=test_queue)
            with open(args.stats_file, 'w') as f:
                # TODO(ahundt) fix "TypeError: 1869 is not JSON serializable" to include arg info, see train.py
                # arg_dict = vars(args)
                # arg_dict.update(eval_stats)
                # json.dump(arg_dict, f)
                json.dump(eval_stats, f)
            logger.info("flops = " + utils.count_model_flops(model))
            logger.info(utils.dict_to_log_string(eval_stats))
            logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' +
                        str(args.save))
        else:
            validate(val_loader, model, criterion, args)
        return

    lr_schedule = cosine_power_annealing(
        epochs=args.epochs,
        max_lr=args.learning_rate,
        min_lr=args.learning_rate_min,
        warmup_epochs=args.warmup_epochs,
        exponent_order=args.lr_power_annealing_exponent_order,
        restart_lr=args.restart_lr)
    epochs = np.arange(args.epochs) + args.start_epoch

    stats_csv = args.epoch_stats_file
    stats_csv = stats_csv.replace('.json', '.csv')
    with tqdm(epochs,
              dynamic_ncols=True,
              disable=args.local_rank != 0,
              leave=False) as prog_epoch:
        best_stats = {}
        stats = {}
        epoch_stats = []
        best_epoch = 0
        for epoch, learning_rate in zip(prog_epoch, lr_schedule):
            if args.distributed and train_loader.sampler is not None:
                train_loader.sampler.set_epoch(int(epoch))
            # if args.distributed:
            # train_sampler.set_epoch(epoch)
            # update the learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            # scheduler.step()
            model.drop_path_prob = args.drop_path_prob * float(epoch) / float(
                args.epochs)
            # train for one epoch
            train_stats = train(train_loader, model, criterion, optimizer,
                                int(epoch), args)
            if args.prof:
                break
            # evaluate on validation set
            top1, val_stats = validate(val_loader, model, criterion, args)
            stats.update(train_stats)
            stats.update(val_stats)
            # stats['lr'] = '{0:.5f}'.format(scheduler.get_lr()[0])
            stats['lr'] = '{0:.5f}'.format(learning_rate)
            stats['epoch'] = epoch

            # remember best top1 and save checkpoint
            if args.local_rank == 0:
                is_best = top1 > best_top1
                best_top1 = max(top1, best_top1)
                stats['best_top1'] = '{0:.3f}'.format(best_top1)
                if is_best:
                    best_epoch = epoch
                    best_stats = copy.deepcopy(stats)
                stats['best_epoch'] = best_epoch

                stats_str = utils.dict_to_log_string(stats)
                logger.info(stats_str)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_top1': best_top1,
                        'optimizer': optimizer.state_dict(),
                        # 'lr_scheduler': scheduler.state_dict()
                        'lr_schedule': lr_schedule,
                        'stats': best_stats
                    },
                    is_best,
                    path=args.save)
                prog_epoch.set_description(
                    'Overview ***** best_epoch: {0} best_valid_top1: {1:.2f} ***** Progress'
                    .format(best_epoch, best_top1))
            epoch_stats += [copy.deepcopy(stats)]
            with open(args.epoch_stats_file, 'w') as f:
                json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
            utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
        stats_str = utils.dict_to_log_string(best_stats, key_prepend='best_')
        logger.info(stats_str)
        with open(args.stats_file, 'w') as f:
            arg_dict = vars(args)
            arg_dict.update(best_stats)
            json.dump(arg_dict, f, cls=utils.NumpyEncoder)
        with open(args.epoch_stats_file, 'w') as f:
            json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
        utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
        logger.info('Training of Final Model Complete! Save dir: ' +
                    str(args.save))
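Resuming above hinges on a checkpoint dict that bundles the epoch, model and optimizer state, and any schedule. A minimal save/resume sketch of that layout (keys mirror the snippet, but the model is a stand-in):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# save
torch.save({
    'epoch': 5,
    'state_dict': model.state_dict(),
    'best_top1': 92.3,
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pt')

# resume, remapping storage onto a chosen device as in resume() above
checkpoint = torch.load('checkpoint.pt', map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']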
Example 10
    def __init__(self,
                 test_args: Namespace,
                 my_dataset: MyDataset,
                 model: nn.Module = None):

        self.__device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        np.random.seed(test_args.seed)
        torch.manual_seed(test_args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True

        logging.info(f'gpu device = {test_args.gpu}')
        logging.info(f'args = {test_args}')

        if model is None:
            # equal to: genotype = genotypes.DARTS_v2
            if not (test_args.arch or test_args.arch_path):
                logging.info('need to designate arch.')
                sys.exit(1)

            genotype = eval(
                f'genotypes.{test_args.arch}'
            ) if not test_args.arch_path else utils.load_genotype(
                test_args.arch_path)
            print('Load genotype:', genotype)

            if my_dataset is MyDataset.CIFAR10:
                model = NetworkCIFAR(test_args.init_ch, 10, test_args.layers,
                                     test_args.auxiliary,
                                     genotype).to(self.__device)
            elif my_dataset is MyDataset.CIFAR100:
                model = NetworkCIFAR(test_args.init_ch, 100, test_args.layers,
                                     test_args.auxiliary,
                                     genotype).to(self.__device)
            elif my_dataset is MyDataset.ImageNet:
                model = NetworkImageNet(test_args.init_ch, 1000,
                                        test_args.layers, test_args.auxiliary,
                                        genotype).to(self.__device)
            else:
                raise Exception('No match MyDataset')

            utils.load(model, test_args.model_path, False)
            model = model.to(self.__device)

            param_size = utils.count_parameters_in_MB(model)
            logging.info(f'param size = {param_size}MB')

        model.drop_path_prob = test_args.drop_path_prob
        self.__model = model

        self.__args = test_args
        self.__criterion = nn.CrossEntropyLoss().to(self.__device)

        if my_dataset is MyDataset.CIFAR10:
            _, test_transform = utils._data_transforms_cifar10(test_args)
            test_data = dset.CIFAR10(root=test_args.data,
                                     train=False,
                                     download=True,
                                     transform=test_transform)

        elif my_dataset is MyDataset.CIFAR100:
            _, test_transform = utils._data_transforms_cifar100(test_args)
            test_data = dset.CIFAR100(root=test_args.data,
                                      train=False,
                                      download=True,
                                      transform=test_transform)

        elif my_dataset is MyDataset.ImageNet:
            validdir = test_args.data / 'val'
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            valid_data = dset.ImageFolder(
                validdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))
            test_data = valid_data
        else:
            raise Exception('No match MyDataset')

        self.__test_queue = torch.utils.data.DataLoader(
            test_data,
            batch_size=test_args.batchsz,
            shuffle=False,
            pin_memory=True,
            num_workers=4)
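The test queue above is a standard torchvision CIFAR-10 test split wrapped in a non-shuffled DataLoader. A self-contained sketch with an explicit normalize transform (using the CIFAR mean/std quoted earlier in this page) in place of utils._data_transforms_cifar10:

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_data = dset.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_queue = torch.utils.data.DataLoader(
    test_data, batch_size=96, shuffle=False, pin_memory=True, num_workers=4)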
Example 11
    def __init__(self,
                 args: Namespace,
                 genotype: Genotype,
                 my_dataset: MyDataset,
                 choose_cell=False):

        self.__args = args
        self.__dataset = my_dataset
        self.__previous_epochs = 0

        if args.seed is None:
            raise Exception('designate seed.')
        elif args.epochs is None:
            raise Exception('designate epochs.')
        if not (args.arch or args.arch_path):
            raise Exception('need to designate arch.')

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        np.random.seed(args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.manual_seed(args.seed)

        logging.info(f'gpu device = {args.gpu}')
        logging.info(f'args = {args}')

        logging.info(f'Train genotype: {genotype}')

        if my_dataset == MyDataset.CIFAR10:
            self.model = NetworkCIFAR(args.init_ch, 10, args.layers,
                                      args.auxiliary, genotype)
            train_transform, valid_transform = utils._data_transforms_cifar10(
                args)
            train_data = dset.CIFAR10(root=args.data,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
            valid_data = dset.CIFAR10(root=args.data,
                                      train=False,
                                      download=True,
                                      transform=valid_transform)

        elif my_dataset == MyDataset.CIFAR100:
            self.model = NetworkCIFAR(args.init_ch, 100, args.layers,
                                      args.auxiliary, genotype)
            train_transform, valid_transform = utils._data_transforms_cifar100(
                args)
            train_data = dset.CIFAR100(root=args.data,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
            valid_data = dset.CIFAR100(root=args.data,
                                       train=False,
                                       download=True,
                                       transform=valid_transform)

        elif my_dataset == MyDataset.ImageNet:
            self.model = NetworkImageNet(args.init_ch, 1000, args.layers,
                                         args.auxiliary, genotype)
            self.__criterion_smooth = CrossEntropyLabelSmooth(
                1000, args.label_smooth).to(device)
            traindir = os.path.join(args.data, 'train')
            validdir = os.path.join(args.data, 'val')
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            train_data = dset.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(brightness=0.4,
                                           contrast=0.4,
                                           saturation=0.4,
                                           hue=0.2),
                    transforms.ToTensor(),
                    normalize,
                ]))
            valid_data = dset.ImageFolder(
                validdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))
        else:
            raise Exception('No match Dataset')

        checkpoint = None
        if use_DataParallel:
            print('use Data Parallel')
            if args.checkpoint_path:
                checkpoint = torch.load(args.checkpoint_path)
                utils.load(self.model, checkpoint['state_dict'],
                           args.to_parallel)
                self.__previous_epochs = checkpoint['epoch']
                args.epochs -= self.__previous_epochs
                if args.epochs <= 0:
                    raise Exception('args.epochs is too small.')

            self.model = nn.DataParallel(self.model)
            self.__module = self.model.module
            torch.cuda.manual_seed_all(args.seed)
        else:
            if args.checkpoint_path:
                checkpoint = torch.load(args.checkpoint_path)
                utils.load(self.model, checkpoint['state_dict'],
                           args.to_parallel)
                args.epochs -= checkpoint['epoch']
                if args.epochs <= 0:
                    raise Exception('args.epochs is too small.')
            torch.cuda.manual_seed(args.seed)
            self.__module = self.model

        self.model.to(device)

        param_size = utils.count_parameters_in_MB(self.model)
        logging.info(f'param size = {param_size}MB')

        self.__criterion = nn.CrossEntropyLoss().to(device)

        self.__optimizer = torch.optim.SGD(self.__module.parameters(),
                                           args.lr,
                                           momentum=args.momentum,
                                           weight_decay=args.wd)
        if checkpoint:
            self.__optimizer.load_state_dict(checkpoint['optimizer'])

        num_workers = torch.cuda.device_count() * 4
        if choose_cell:
            num_train = len(train_data)  # 50000
            indices = list(range(num_train))
            split = int(np.floor(args.train_portion * num_train))  # 25000

            self.__train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    indices[:split]),
                pin_memory=True,
                num_workers=num_workers)

            self.__valid_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    indices[split:]),
                pin_memory=True,
                num_workers=num_workers)
        else:
            self.__train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batchsz,
                shuffle=True,
                pin_memory=True,
                num_workers=num_workers)

            self.__valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=args.batchsz,
                shuffle=False,
                pin_memory=True,
                num_workers=num_workers)

        # membership test: `== CIFAR10 or MyDataset.CIFAR100` would always be truthy
        if my_dataset in (MyDataset.CIFAR10, MyDataset.CIFAR100):
            self.__scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.__optimizer, args.epochs)
        elif my_dataset == MyDataset.ImageNet:
            self.__scheduler = torch.optim.lr_scheduler.StepLR(
                self.__optimizer, args.decay_period, gamma=args.gamma)
        else:
            raise Exception('No match Dataset')

        if checkpoint:
            self.__scheduler.load_state_dict(checkpoint['scheduler'])
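When choose_cell is set, the code above carves the training set into train/valid queues with SubsetRandomSampler rather than using a separate validation dataset. The same split in isolation on a dummy TensorDataset:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

train_data = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))
train_portion = 0.5
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(train_portion * num_train))

train_queue = DataLoader(train_data, batch_size=64,
                         sampler=SubsetRandomSampler(indices[:split]), pin_memory=True)
valid_queue = DataLoader(train_data, batch_size=64,
                         sampler=SubsetRandomSampler(indices[split:]), pin_memory=True)
print(len(train_queue.sampler), len(valid_queue.sampler))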