Example #1
    def create_segmenter(encoder, decoder_config):
        with torch.no_grad():
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=NUM_CLASSES[args.dataset_type][0],
                config=decoder_config,
                agg_size=48,  # hard-coded instead of args.agg_cell_size
                aux_cell=True,  # hard-coded instead of args.aux_cell
                repeats=1)  # hard-coded instead of args.sep_repeats

        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter  #, entropy, log_prob
Example #2
    def create_segmenter(encoder):
        with torch.no_grad():
            decoder_config, entropy, log_prob = agent.controller.sample()
            decoder = Decoder(inp_sizes=encoder.out_sizes,
                              num_classes=args.num_classes[0],
                              config=decoder_config,
                              agg_size=args.agg_cell_size,
                              aux_cell=args.aux_cell,
                              repeats=args.sep_repeats)

        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter, decoder_config, entropy, log_prob
Example #3
    def create_segmenter(encoder):
        if args.ctrl_version == "cvpr":
            from nn.micro_decoders import MicroDecoder as Decoder
        elif args.ctrl_version == "wacv":
            from nn.micro_decoders import TemplateDecoder as Decoder
        with torch.no_grad():
            decoder_config, entropy, log_prob = agent.controller.sample()
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=args.num_classes[0],
                config=decoder_config,
                agg_size=args.agg_cell_size,
                aux_cell=args.aux_cell,
                repeats=args.sep_repeats,
            )

        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter, decoder_config, entropy, log_prob
Example #4
def main():
    # Set-up experiment
    args = get_arguments()
    logger = logging.getLogger(__name__)
    logger.debug(args)
    exp_name = time.strftime("%H_%M_%S")
    dir_name = "{}/{}".format(args.summary_dir, exp_name)
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    arch_writer = open("{}/genotypes.out".format(dir_name), "w")
    logger.info(" Running Experiment {}".format(exp_name))
    args.num_tasks = len(args.num_classes)
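    # Per-pixel negative log-likelihood loss; pixels labelled 255 (void) are ignored.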
    segm_crit = nn.NLLLoss2d(ignore_index=255).cuda()

    # Set-up random seeds
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Initialise encoder
    encoder = create_encoder(ctrl_version=args.ctrl_version)
    logger.info(" Loaded Encoder with #TOTAL PARAMS={:3.2f}M".format(
        compute_params(encoder)[0] / 1e6))

    # Generate teacher if any
    kd_net = None
    kd_crit = None
    if args.do_kd:
        from kd.rf_lw.model_lw_v2 import rf_lw152 as kd_model

        kd_crit = nn.MSELoss().cuda()
        kd_net = (kd_model(pretrained=True,
                           num_classes=args.num_classes[0]).cuda().eval())
        logger.info(" Loaded teacher, #TOTAL PARAMS={:3.2f}M".format(
            compute_params(kd_net)[0] / 1e6))

    # Generate controller / RL-agent
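    # An LSTM-based controller emits decoder configurations; the arguments below
    # define the search space and the agent's learning settings.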
    agent = create_agent(
        enc_num_layers=len(encoder.out_sizes),
        num_ops=args.num_ops,
        num_agg_ops=args.num_agg_ops,
        lstm_hidden_size=args.lstm_hidden_size,
        lstm_num_layers=args.lstm_num_layers,
        dec_num_cells=args.dec_num_cells,
        cell_num_layers=args.cell_num_layers,
        cell_max_repeat=args.cell_max_repeat,
        cell_max_stride=args.cell_max_stride,
        ctrl_lr=args.ctrl_lr,
        ctrl_baseline_decay=args.ctrl_baseline_decay,
        ctrl_agent=args.ctrl_agent,
        ctrl_version=args.ctrl_version,
    )
    logger.info(" Loaded Controller, #TOTAL PARAMS={:3.2f}M".format(
        compute_params(agent.controller)[0] / 1e6))

    def create_segmenter(encoder):
        if args.ctrl_version == "cvpr":
            from nn.micro_decoders import MicroDecoder as Decoder
        elif args.ctrl_version == "wacv":
            from nn.micro_decoders import TemplateDecoder as Decoder
        with torch.no_grad():
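            # Sample a decoder architecture from the controller, together with the
            # entropy and log-probability of the sample (used later when updating the controller).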
            decoder_config, entropy, log_prob = agent.controller.sample()
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=args.num_classes[0],
                config=decoder_config,
                agg_size=args.agg_cell_size,
                aux_cell=args.aux_cell,
                repeats=args.sep_repeats,
            )

        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter, decoder_config, entropy, log_prob

    # Sample first configuration
    segmenter, decoder_config, entropy, log_prob = create_segmenter(encoder)
    del encoder

    # Create dataloaders
    train_loader, val_loader, do_search = create_loaders(args)

    # Initialise task performance measurers
    task_ps = [[
        TaskPerformer(maxval=0.01, delta=0.9)
        for _ in range(args.num_segm_epochs[idx] // args.val_every[idx])
    ] for idx in range(args.num_tasks)]

    # Restore from previous checkpoint if any
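    # load_ckpt returns the best reward observed so far and the epoch to resume from.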
    best_val, epoch_start = load_ckpt(args.ckpt_path, {"agent": agent})

    # Saver: keeping checkpoint with best validation score (a.k.a best reward)
    saver = Saver(
        args=vars(args),
        ckpt_dir=args.snapshot_dir,
        best_val=best_val,
        condition=lambda x, y: x > y,
    )

    logger.info(" Pre-computing data for task0")
    Xy_train = populate_task0(segmenter, train_loader, kd_net, args.n_task0,
                              args.do_kd)
    if args.do_kd:
        del kd_net

    logger.info(" Training Process Starts")
    for epoch in range(epoch_start, args.num_epochs):
        reward = 0.0
        start = time.time()
        torch.cuda.empty_cache()
        logger.info(" Training Segmenter, Arch {}".format(str(epoch)))
        stop = False
        for task_idx in range(args.num_tasks):
            if stop:
                break
            torch.cuda.empty_cache()
            # Change dataloader
            train_loader.batch_sampler.batch_size = args.batch_size[task_idx]
            for loader in [train_loader, val_loader]:
                try:
                    loader.dataset.set_config(
                        crop_size=args.crop_size[task_idx],
                        shorter_side=args.shorter_side[task_idx],
                    )
                except AttributeError:
                    # dataset is wrapped in a Subset; configure the underlying dataset
                    loader.dataset.dataset.set_config(
                        crop_size=args.crop_size[task_idx],
                        resize_side=args.resize_side[task_idx],
                    )

            logger.info(" Training Task {}".format(str(task_idx)))
            # Optimisers
            optim_enc, optim_dec = create_optimisers(
                args.enc_optim,
                args.dec_optim,
                args.enc_lr[task_idx],
                args.dec_lr[task_idx],
                args.enc_mom[task_idx],
                args.dec_mom[task_idx],
                args.enc_wd[task_idx],
                args.dec_wd[task_idx],
                segmenter.module.encoder.parameters(),
                segmenter.module.decoder.parameters(),
            )
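            # Polyak (moving-average) copy of the weights, when enabled:
            # decoder only on task 0, the whole segmenter afterwards.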
            avg_param = init_polyak(
                args.do_polyak,
                segmenter.module.decoder if task_idx == 0 else segmenter)
            for epoch_segm in range(args.num_segm_epochs[task_idx]):
                if task_idx == 0:
                    train_task0(
                        Xy_train,
                        segmenter,
                        optim_dec,
                        epoch_segm,
                        segm_crit,
                        kd_crit,
                        args.batch_size[0],
                        args.freeze_bn[0],
                        args.do_kd,
                        args.kd_coeff,
                        args.dec_grad_clip,
                        args.do_polyak,
                        avg_param=avg_param,
                        polyak_decay=0.9,
                        aux_weight=args.dec_aux_weight,
                    )
                else:
                    train_segmenter(
                        segmenter,
                        train_loader,
                        optim_enc,
                        optim_dec,
                        epoch_segm,
                        segm_crit,
                        args.freeze_bn[1],
                        args.enc_grad_clip,
                        args.dec_grad_clip,
                        args.do_polyak,
                        args.print_every,
                        aux_weight=args.dec_aux_weight,
                        avg_param=avg_param,
                        polyak_decay=0.99,
                    )
                apply_polyak(
                    args.do_polyak,
                    segmenter.module.decoder if task_idx == 0 else segmenter,
                    avg_param,
                )
                if (epoch_segm + 1) % (args.val_every[task_idx]) == 0:
                    logger.info(
                        " Validating Segmenter, Arch {}, Task {}".format(
                            str(epoch), str(task_idx)))
                    task_miou = validate(
                        segmenter,
                        val_loader,
                        epoch,
                        epoch_segm,
                        num_classes=args.num_classes[task_idx],
                        print_every=args.print_every,
                        omit_classes=args.val_omit_classes,
                    )
                    # Check whether to keep training this architecture or stop early.
                    c_task_ps = task_ps[task_idx][(epoch_segm + 1) //
                                                  args.val_every[task_idx] - 1]
                    if c_task_ps.step(task_miou):
                        continue
                    else:
                        logger.info(" Interrupting")
                        stop = True
                        break
            reward = task_miou
        if do_search:
            logger.info(" Training Controller")
            sample = (decoder_config, reward, entropy, log_prob)
            train_agent(agent, sample)
            # Log this epoch
            _, params = compute_params(segmenter)
            logger.info(" Decoder: {}".format(decoder_config))
            # Save controller params
            saver.save(reward, {
                "agent": agent.state_dict(),
                "epoch": epoch
            }, logger)
            # Save genotypes
            epoch_time = (time.time() - start) / sum(
                args.num_segm_epochs[:(task_idx + 1)])
            arch_writer.write(
                "reward: {:.4f}, epoch: {}, params: {}, epoch_time: {:.4f}, genotype: {}\n"
                .format(reward, epoch, params, epoch_time, decoder_config))
            arch_writer.flush()
            # Sample a new architecture
            del segmenter
            encoder = create_encoder(ctrl_version=args.ctrl_version)
            segmenter, decoder_config, entropy, log_prob = create_segmenter(
                encoder)
            del encoder
Example #5
def main():
    # Set-up experiment
    args = get_arguments()
    logger = logging.getLogger(__name__)
    exp_name = time.strftime('%H_%M_%S')
    # dir_name = '{}/{}'.format(args.summary_dir, exp_name)
    # if not os.path.exists(dir_name):
    #     os.makedirs(dir_name)
    # arch_writer = open('{}/genotypes.out'.format(dir_name), 'w')
    logger.info(" Running Experiment {}".format(exp_name))
    args.num_tasks = len(NUM_CLASSES[args.dataset_type])
    segm_crit = nn.NLLLoss2d(ignore_index=255).cuda()
    # Set-up random seeds
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Create dataloaders
    train_loader, val_loader, do_search = create_loaders(args)

    def create_segmenter(encoder, decoder_config):
        with torch.no_grad():
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=NUM_CLASSES[args.dataset_type][0],
                config=decoder_config,
                agg_size=48,  # hard-coded instead of args.agg_cell_size
                aux_cell=True,  # hard-coded instead of args.aux_cell
                repeats=1)  # hard-coded instead of args.sep_repeats

        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter  #, entropy, log_prob

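    # Train and evaluate each pre-defined decoder genotype in turn.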
    for decoder_config in decoder_config_arry:
        # Initialise encoder
        encoder = create_encoder()
        logger.info(" Loaded Encoder with #TOTAL PARAMS={:3.2f}M".format(
            compute_params(encoder)[0] / 1e6))
        # Sample first configuration
        segmenter = create_segmenter(encoder, decoder_config)
        del encoder

        logger.info(" Loaded Encoder with #TOTAL PARAMS={:3.2f}M".format(
            compute_params(segmenter)[0] / 1e6))

        # Saver: keeping checkpoint with best validation score (a.k.a best reward)
        now = datetime.datetime.now()

        snapshot_dir = args.snapshot_dir + '_train_' + args.dataset_type + "_{:%Y%m%dT%H%M}".format(
            now)
        seg_saver = seg_Saver(ckpt_dir=snapshot_dir)

        arch_writer = open('{}/genotypes.out'.format(snapshot_dir), 'w')
        arch_writer.write('genotype: {}\n'.format(decoder_config))
        arch_writer.flush()

        logger.info(" Pre-computing data for task0")
        kd_net = None  # knowledge distillation disabled (no teacher network)

        logger.info(" Training Process Starts")
        for task_idx in range(args.num_tasks):  # tasks 0 and 1; task 0 is skipped below
            if task_idx == 0:
                continue
            torch.cuda.empty_cache()
            # Change dataloader
            train_loader.batch_sampler.batch_size = BATCH_SIZE[
                args.dataset_type][task_idx]

            logger.info(" Training Task {}".format(str(task_idx)))
            # Optimisers
            optim_enc, optim_dec = create_optimisers(
                args.optim_enc, args.optim_dec, args.lr_enc[task_idx],
                args.lr_dec[task_idx], args.mom_enc[task_idx],
                args.mom_dec[task_idx], args.wd_enc[task_idx],
                args.wd_dec[task_idx], segmenter.module.encoder.parameters(),
                segmenter.module.decoder.parameters())
            kd_crit = None  # knowledge distillation disabled, so no distillation loss
            for epoch_segm in range(
                    TRAIN_EPOCH_NUM[args.dataset_type][task_idx]):  # epochs per task, e.g. [5, 1] or [20, 8]
                final_loss = train_segmenter(
                    segmenter,  # train the segmenter end-to-end
                    train_loader,
                    optim_enc,
                    optim_dec,
                    epoch_segm,
                    segm_crit,
                    args.freeze_bn[1],
                    args.enc_grad_clip,
                    args.dec_grad_clip,
                    args.do_polyak,
                    args.print_every,
                    aux_weight=args.dec_aux_weight,
                    # avg_param=avg_param,
                    polyak_decay=0.99)
        seg_saver.save(final_loss, segmenter.state_dict(), logger)  # checkpoint keyed by the final training loss
        # Validate the trained segmenter on the held-out set
        segmenter.eval()
        data_file = dataset_dirs[args.dataset_type]['VAL_LIST']
        data_dir = dataset_dirs[args.dataset_type]['VAL_DIR']
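        # Each line of VAL_LIST holds tab-separated image and mask paths (relative to
        # VAL_DIR); lists without masks fall back to using the image path for both.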
        with open(data_file, 'rb') as f:
            datalist = f.readlines()
        try:
            datalist = [
                (k, v) for k, v, _ in \
                map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]
        except ValueError:  # Adhoc for test.
            datalist = [
                (k, k)
                for k in map(lambda x: x.decode('utf-8').strip('\n'), datalist)
            ]
        imgs_all = [
            os.path.join(data_dir, datalist[i][0])
            for i in range(0, len(datalist))
        ]
        msks_all = [
            os.path.join(data_dir, datalist[i][1])
            for i in range(0, len(datalist))
        ]
        validate_output_dir = os.path.join(
            dataset_dirs[args.dataset_type]['VAL_DIR'], 'validate_output')
        validate_gt_dir = os.path.join(
            dataset_dirs[args.dataset_type]['VAL_DIR'], 'validate_gt')
        if not os.path.exists(validate_output_dir):
            os.makedirs(validate_output_dir)
        else:
            shutil.rmtree(validate_output_dir)
            os.makedirs(validate_output_dir)

        if not os.path.exists(validate_gt_dir):
            os.makedirs(validate_gt_dir)
        else:
            shutil.rmtree(validate_gt_dir)
            os.makedirs(validate_gt_dir)
        # validate_color_dir = os.path.join(dataset_dirs[args.dataset_type]['VAL_DIR'], 'validate_output_color')
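        # Per-image inference: forward pass, upsample to the original resolution, take the
        # per-pixel argmax, and save prediction and ground-truth masks for scoring.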
        for i, img_path in enumerate(imgs_all):
            # logger.info("Testing image:{}".format(img_path))
            img = np.array(Image.open(img_path))
            msk = np.array(Image.open(msks_all[i]))
            orig_size = img.shape[:2][::-1]

            img_inp = torch.tensor(prepare_img(img).transpose(
                2, 0, 1)[None]).float().to(device)
            segm = segmenter(
                img_inp)[0].squeeze().data.cpu().numpy().transpose(
                    (1, 2, 0))  # 47*63*21
            # Upsample the class scores back to the original image size
            # (for celebA, the mask could instead be resized down to the score map:
            #  msk = cv2.resize(msk, segm.shape[0:2], interpolation=cv2.INTER_NEAREST))
            segm = cv2.resize(segm,
                              orig_size,
                              interpolation=cv2.INTER_CUBIC)  # e.g. 375*500*21
            segm = segm.argmax(axis=2).astype(np.uint8)

            image_name = img_path.split('/')[-1].split('.')[0]
            # image_name = val_loader.dataset.datalist[i][0].split('/')[1].split('.')[0]
            # cv2.imwrite(os.path.join(validate_color_dir, "{}.png".format(image_name)), color_array[segm])
            # cv2.imwrite(os.path.join(validate_gt_dir, "{}.png".format(image_name)), color_array[msk])
            cv2.imwrite(
                os.path.join(validate_output_dir, "{}.png".format(image_name)),
                segm)
            cv2.imwrite(
                os.path.join(validate_gt_dir, "{}.png".format(image_name)),
                msk)

        if args.dataset_type == 'celebA':
            cal_f1_score_celebA(validate_gt_dir, validate_output_dir,
                                arch_writer)  # temp comment