Example #1
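# Assumed module-level context for these snippets (not shown in the source):
#   import glob, logging, os, random, sys, time
#   import numpy as np
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#   import torch.backends.cudnn as cudnn
#   from torch.utils.data import DataLoader
# `args` comes from the project's argument parsing; Saver, TensorboardSummary,
# MiniImagenet, Meta, Network, Meta_decoding, Architect, utils, meta_train,
# meta_test, and get_model_complexity_info (presumably from ptflops) are
# project-level imports.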
def main():
    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    # Compute FLOPs and Params (disabled here; Example #3 runs the live version)
    # maml = Meta(args, criterion)
    # flops, params = get_model_complexity_info(maml.model, (3, 84, 84), as_strings=False, print_per_layer_stat=True)
    # logging.info('FLOPs: {} MMac Params: {}'.format(flops / 10 ** 6, params))
    #
    # maml = Meta(args, criterion).to(device)
    # tmp = filter(lambda x: x.requires_grad, maml.parameters())
    # num = sum(map(lambda x: np.prod(x.shape), tmp))
    # logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here is the total number of episodes
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size,
                        task_id=None)
    mini_test = MiniImagenet(args.data_path,
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size,
                             task_id=args.task_id)
    train_loader = DataLoader(mini,
                              args.meta_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    test_loader = DataLoader(mini_test,
                             args.meta_test_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    # Decoding
    model = Network(args,
                    args.init_channels,
                    args.n_way,
                    args.layers,
                    criterion,
                    pretrained=True).cuda()
    inner_optimizer_theta = torch.optim.SGD(model.arch_parameters(),
                                            lr=args.update_lr_theta)
    inner_optimizer_w = torch.optim.SGD(model.parameters(),
                                        lr=args.update_lr_w)

    # Load pretrained weights. Checkpoint keys carry a wrapper prefix
    # (presumably 'model.', 6 characters), which is stripped so the keys
    # match this Network's state dict.
    pretrained_path = '/data2/dongzelian/NAS/meta_nas/run_meta_nas/mini-imagenet/meta-nas/experiment_21/model_best.pth.tar'
    pretrain_dict = torch.load(pretrained_path)['state_dict_w']
    model_dict = {}
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items():
        if k[6:] in state_dict:
            model_dict[k[6:]] = v
        else:
            logging.info('unmatched checkpoint key: {}'.format(k))
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    #model._arch_parameters = torch.load(pretrained_path)['state_dict_theta']

    # Decode a task-specific architecture: adapt both weights and alphas on the
    # support set of each test episode, then read off the genotype. Episodes
    # arrive with a leading meta-batch dim of 1, hence the squeeze(0).
    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(test_loader):
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
        for k in range(args.update_step_test):
            logits = model(x_spt, alphas=model.arch_parameters())
            loss = criterion(logits, y_spt)

            inner_optimizer_w.zero_grad()
            inner_optimizer_theta.zero_grad()
            loss.backward()
            inner_optimizer_w.step()
            inner_optimizer_theta.step()

        genotype = model.genotype()
        logging.info(genotype)
        # the genotype decoded from the last episode is used for meta-training below
        maml = Meta_decoding(args, criterion, genotype).to(device)

    for epoch in range(args.epoch):
        logging.info('--------- Epoch: {} ----------'.format(epoch))
        accs_all_train = []
        # # TODO: how to choose batch data to update theta?
        # valid_iterator = iter(train_loader)
        batch_time = utils.AverageMeter()
        data_time = utils.AverageMeter()
        update_w_time = utils.AverageMeter()
        end = time.time()
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(train_loader):
            data_time.update(time.time() - end)
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)
            # (x_search_spt, y_search_spt, x_search_qry, y_search_qry), valid_iterator = infinite_get(valid_iterator, train_loader)
            # x_search_spt, y_search_spt, x_search_qry, y_search_qry = x_search_spt.to(device), y_search_spt.to(device), x_search_qry.to(device), y_search_qry.to(device)
            # (a sketch of infinite_get appears after this example)
            accs, update_w_time = maml(x_spt, y_spt, x_qry, y_qry,
                                       update_w_time)
            accs_all_train.append(accs)
            batch_time.update(time.time() - end)
            end = time.time()
            writer.add_scalar('train/acc_iter', accs[-1],
                              step + len(train_loader) * epoch)
            if step % args.report_freq == 0:
                logging.info(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'W {update_w_time.val:.3f} ({update_w_time.avg:.3f})\t'
                    'training acc: {accs}'.format(epoch,
                                                  step,
                                                  len(train_loader),
                                                  batch_time=batch_time,
                                                  data_time=data_time,
                                                  update_w_time=update_w_time,
                                                  accs=accs))

            if step % args.test_freq == 0:
                test_accs, test_stds, test_ci95 = meta_test(
                    train_loader, test_loader, maml, device, epoch, writer)
                logging.info(
                    '[Epoch: {}]\t Test acc: {}\t Test ci95: {}'.format(
                        epoch, test_accs, test_ci95))

                # Save the best meta model.
                new_pred = test_accs[-1]
                if new_pred > best_pred:
                    is_best = True
                    best_pred = new_pred
                else:
                    is_best = False
                saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': maml.module.state_dict()
                        if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                        'best_pred': best_pred,
                    }, is_best)
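# The commented-out lines in the epoch loop above reference `infinite_get`, a
# helper not shown in the source. A minimal sketch of what it presumably does
# (draw the next batch, restarting the iterator once the loader is exhausted);
# the project's actual version may differ:
def infinite_get(iterator, loader):
    try:
        batch = next(iterator)
    except StopIteration:
        # epoch exhausted: restart from a fresh iterator
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator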
Example #2
def main():
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here is the total number of episodes
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size,
                        split=[0, args.train_portion])
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.batch_size,
                              resize=args.img_size,
                              split=[args.train_portion, 1])
    mini_test = MiniImagenet(args.data_path,
                             mode='train',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size,
                             split=[args.train_portion, 1])
    train_queue = DataLoader(mini,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    valid_queue = DataLoader(mini_valid,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    test_queue = DataLoader(mini_test,
                            args.meta_test_batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)
    # Architect performs the architecture (alpha) updates; a sketch follows this example
    architect = Architect(maml.model, args)

    for epoch in range(args.epoch):
        # each iteration fetches a meta-batch of episodes
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_queue, valid_queue, maml, architect,
                                device, criterion, epoch, writer)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
        valid_accs = meta_test(test_queue, maml, device, epoch, writer)
        logging.info('[Epoch: {}]\t Test acc: {}'.format(epoch, valid_accs))

        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        # logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the best meta model.
        new_pred = valid_accs[-1]
        if new_pred > best_pred:
            is_best = True
            best_pred = new_pred
        else:
            is_best = False
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': maml.module.state_dict()
                if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'best_pred': best_pred,
            }, is_best)
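# `Architect` is a project module not shown here. A minimal first-order sketch
# of the DARTS-style update it presumably performs: an Adam optimizer over the
# architecture parameters, stepped on a validation batch. The attribute name
# (arch_parameters) follows the model used above; the hyperparameter names
# (arch_learning_rate, arch_weight_decay) are DARTS conventions and may differ
# here, and the real class likely also supports the unrolled second-order update:
class ArchitectSketch:
    def __init__(self, model, args):
        self.model = model
        self.optimizer = torch.optim.Adam(model.arch_parameters(),
                                          lr=args.arch_learning_rate,
                                          betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

    def step(self, input_valid, target_valid, criterion):
        # first-order update: minimize validation loss w.r.t. alphas only
        self.optimizer.zero_grad()
        loss = criterion(self.model(input_valid), target_valid)
        loss.backward()
        self.optimizer.step()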
Example #3
def main():
    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    # set saver
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    # Compute FLOPs and Params
    maml = Meta(args, criterion)
    flops, params = get_model_complexity_info(maml.model, (3, 84, 84),
                                              as_strings=False,
                                              print_per_layer_stat=True,
                                              verbose=True)
    logging.info('FLOPs: {:.2f} MMac, Params: {}'.format(flops / 10**6, params))

    maml = Meta(args, criterion).to(device)
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    #logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here is the total number of episodes
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size)
    mini_test = MiniImagenet(args.data_path,
                             mode='val',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size)
    train_loader = DataLoader(mini,
                              args.meta_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    test_loader = DataLoader(mini_test,
                             args.meta_test_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)

    # load pretrained model and optionally run inference only
    if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model)
        # unwrap DataParallel if present, mirroring how checkpoints are saved
        if isinstance(maml, torch.nn.DataParallel):
            maml.module.load_state_dict(checkpoint['state_dict'])
        else:
            maml.load_state_dict(checkpoint['state_dict'])

        if args.evaluate:
            test_accs = meta_test(test_loader, maml, device,
                                  checkpoint['epoch'])
            logging.info('[Epoch: {}]\t Test acc: {}'.format(
                checkpoint['epoch'], test_accs))
            return

    # Start training
    for epoch in range(args.epoch):
        # each iteration fetches a meta-batch of episodes
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_loader, maml, device, epoch, writer,
                                test_loader, saver)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
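# `meta_train` and `meta_test` are defined elsewhere in the project. A minimal
# sketch of the evaluation loop meta_test presumably implements, assuming the
# Meta wrapper exposes a `finetunning(x_spt, y_spt, x_qry, y_qry)` method (as in
# common MAML implementations) that adapts on the support set and returns query
# accuracies per inner step; the 95% confidence interval is the usual
# 1.96 * std / sqrt(n) over episodes:
def meta_test_sketch(test_loader, maml, device):
    accs_all = []
    for x_spt, y_spt, x_qry, y_qry in test_loader:
        # episodes arrive with a leading meta-batch dim of 1
        x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
        x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
        accs_all.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))
    accs = np.array(accs_all)
    mean, std = accs.mean(axis=0), accs.std(axis=0)
    ci95 = 1.96 * std / np.sqrt(accs.shape[0])
    return mean, std, ci95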