Example #1
def apply_themes(args, device, model):
    img_path = os.path.join("photos", args.content_image)

    content_image = Image.open(img_path).resize(eval_size)

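    # get segmentation masks for the content image and select which style filters to apply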
    masks = utils.get_masks(content_image, args.seg_threshold)
    filter_ids = utils.select_filters(masks, content_image, total_filters)

    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    content_image = content_transform(content_image)
    content_images = content_image.expand(len(filter_ids), -1, -1,
                                          -1).to(device)
    # one forward pass to render themes
    with torch.no_grad():
        if args.load_model is None or args.load_model == "None":
            theme_model = model
        else:
            # our saved models were trained with gpu 3
            #theme_model = torch.load(args.load_model, map_location={'cuda:3': 'cuda:0'})
            theme_model = torch.load(args.load_model)
            theme_model.eval()
        theme_model.to(device)
        output = theme_model(content_images, filter_ids).cpu()

    output_list = []
    for img in output:
        img = img.clone().clamp(0, 255).numpy()
        img = img.transpose(1, 2, 0).astype("uint8")
        output_list.append(img)

    rendered_themes = utils.render_themes(output_list, masks)
    utils.draw(rendered_themes, args.content_image, args.output_image)
Example #2
def walk_through_dataset(root_folder, depth, start_from=False, plot_evolution=False):
    generator = utils.walk_level(root_folder, depth)
    gens = [[r, f] for r, d, f in generator if len(r.split("/")) == len(root_folder.split("/")) + depth]

    # If start_from is given, skip images until the one whose spec_name matches it is reached
    if start_from:
        go = False
    else:
        go = True

    for root, files in gens:
        images = utils.get_images(files, root)
        masks = utils.get_masks(files, root)
        cages = utils.get_cages(files, root)
        if not images or not masks or not cages:
            if not images:
                print(root, 'has no .png image')
            if not masks:
                print(root, 'has no .png mask')
            if not cages:
                print(root, 'has no .txt cages')
        else:
            # TODO: FIX TO ALLOW MORE IMAGES
            for image in images:
                if image.spec_name == start_from:
                    go = True
                if not go:
                    continue
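                # try every (mask, cage) combination on this image and evaluate each segmentation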
                for mask in masks:
                    for cage in cages:
                        print('\nSegmenting', image.root)
                        aux_cage = copy.deepcopy(cage)
                        resulting_cage = cac_segmenter.cac_segmenter(image, mask, aux_cage, None,
                                                                     plot_evolution=plot_evolution)
                        evaluate_results(image, cage, mask, resulting_cage, files, root)
Example #3
    def inference_minor_util60(role_id, handcards, num, is_pair, dup_mask, main_cards_char):
        for main_card in main_cards_char:
            handcards.remove(main_card)

        s = get_mask(handcards, action_space, None).astype(np.float32)
        outputs = []
        minor_type = 1 if is_pair else 0
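        # pick the minor cards one at a time; dup_mask zeroes out choices that were already taken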
        for i in range(num):
            input_single, input_pair, _, _ = get_masks(handcards, None)
            _, _, _, _, _, _, minor_response_prob = func(
                [np.array([role_id]), s.reshape(1, -1), np.zeros([1, 9085]), np.array([minor_type])]
            )

            # give minor cards
            mask = None
            if is_pair:
                mask = np.concatenate([input_pair, [0, 0]]) * dup_mask
            else:
                mask = input_single * dup_mask

            minor_response = take_action_from_prob(minor_response_prob, mask)
            dup_mask[minor_response] = 0

            # convert network output to char cards
            handcards.remove(to_char(minor_response + 3))
            if is_pair:
                handcards.remove(to_char(minor_response + 3))
            s = get_mask(handcards, action_space, None).astype(np.float32)

            # save to output
            outputs.append(to_char(minor_response + 3))
            if is_pair:
                outputs.append(to_char(minor_response + 3))
        return outputs
Example #4
 def _load_all_masks(self, label_all, img_size):
     print('preparing all masks data.....')
     masks = []
     for labels in label_all:
         mask = get_masks(labels, img_size)
         masks.append(mask)
     return masks
Example #5
 def get_mask(self):
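     # return a mask over the legal outputs for the current acting mode (passive vs. active)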
     if self.act == ACT_TYPE.PASSIVE:
         decision_mask, response_mask, bomb_mask, _ = get_mask_alter(
             self.curr_handcards_char, to_char(self.last_cards_value),
             self.category)
         if self.mode == MODE.PASSIVE_DECISION:
             return decision_mask
         elif self.mode == MODE.PASSIVE_RESPONSE:
             return response_mask
         elif self.mode == MODE.PASSIVE_BOMB:
             return bomb_mask
         elif self.mode == MODE.MINOR_RESPONSE:
             input_single, input_pair, _, _ = get_masks(
                 self.curr_handcards_char, None)
             if self.minor_type == 1:
                 mask = np.append(input_pair, [0, 0])
             else:
                 mask = input_single
             for v in set(self.intention):
                 mask[v - 3] = 0
             return mask
     elif self.act == ACT_TYPE.ACTIVE:
         decision_mask, response_mask, _, length_mask = get_mask_alter(
             self.curr_handcards_char, [], self.category)
         if self.mode == MODE.ACTIVE_DECISION:
             return decision_mask
         elif self.mode == MODE.ACTIVE_RESPONSE:
             return response_mask[self.active_decision]
         elif self.mode == MODE.ACTIVE_SEQ:
             return length_mask[self.active_decision][self.active_response]
         elif self.mode == MODE.MINOR_RESPONSE:
             input_single, input_pair, _, _ = get_masks(
                 self.curr_handcards_char, None)
             if self.minor_type == 1:
                 mask = np.append(input_pair, [0, 0])
             else:
                 mask = input_single
             for v in set(self.intention):
                 mask[v - 3] = 0
             return mask
Example #6
    def __getitem__(self, index):
        img = np.array(self.img_all[index], copy=True)
        self.img_size = img.shape[0:2]
        # seg_mask = np.array(self.segmask_all[index], copy=True)
        labels = np.array(self.label_all[index], copy=True)
        labels1 = np.array(self.label_all[index], copy=True)
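        # build masks from the labels at the image's spatial resolution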
        masks = get_masks(labels1, self.img_size)

        img_region = np.expand_dims(img, 0)
        masks_region = masks
        img_region = img_region / img_region.max()

        return np.array(img_region), np.array(masks_region), np.array(labels)
Example #7
def run(input_path):
    print('Read input')
    imgs = utils.read_input(input_path)

    print('Calibrate camera')
    camera_mat, dist_coeffs = utils.calibrate_camera()

    print('Undistort images')
    imgs_undistorted = utils.undistort_imgs(imgs, camera_mat, dist_coeffs)

    print('Get masks')
    masks = utils.get_masks(imgs_undistorted)

    print('Transform perspective')
    masks_birdseye = utils.birdseye(masks)

    print('Find lines')
    lines, history = utils.find_lines(masks_birdseye)

    print('Draw lanes')
    imgs_superimposed = utils.draw_lane(imgs_undistorted, lines, history)

    return imgs_superimposed
Example #8
# Undistorted calibration image
calibration_img = utils.read_input('../data/camera_cal/calibration1.jpg')
camera_mat, dist_coeffs = utils.calibrate_camera()
cal_undistorted = utils.undistort_imgs(calibration_img, camera_mat,
                                       dist_coeffs)
cal_save_path = join(c.SAVE_DIR, 'undistort_cal.jpg')
utils.save(cal_undistorted, cal_save_path)

# Undistorted road image
road_img = utils.read_input('../data/test_images/test3.jpg')
road_undistorted = utils.undistort_imgs(road_img, camera_mat, dist_coeffs)
road_save_path = join(c.SAVE_DIR, 'undistort_road.jpg')
utils.save(road_undistorted, road_save_path)

# Mask image
mask = utils.get_masks(road_undistorted)
mask_save_path = join(c.SAVE_DIR, 'mask.jpg')
utils.save(mask * 255, mask_save_path)

# Birdseye transform
birdseye = utils.birdseye(mask)
birdseye_save_path = join(c.SAVE_DIR, 'birdseye.jpg')
utils.save(birdseye * 255, birdseye_save_path)

# Fit lines
find_fit = utils.visualize_find_fit(birdseye[0])
find_fit_save_path = join(c.SAVE_DIR, 'find_fit.jpg')
utils.save(find_fit, find_fit_save_path)

# Output
lines, history = utils.find_lines(birdseye)
Example #9
def evaluate(eval_loader, name='val'):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    start_time = time.time()
    #NUM_Q_SAMPLES = opt.num_q_samples
    NUM_Z_SAMPLES = opt.num_z_samples
    NUM_ITERS = min(opt.num_iters, len(eval_loader))
    MOD_STEP = opt.mod_step  #5 #3
    NUM_MODS = sorted(
        list(
            set([1] +
                [n_mods for n_mods in range(2, num_modalities + 1, MOD_STEP)] +
                [num_modalities])))
    NUM_CONTEXTS = [0, 1, 2, 3, 4, 5]
    NUM_TARGET = 5
    MASK_STEP = opt.mask_step  #10 #5

    all_masks = []
    for n_mods in NUM_MODS:
        masks = get_masks(num_modalities, min_modes=n_mods, max_modes=n_mods)
        masks = list(set(masks[::MASK_STEP] + [masks[-1]]))
        all_masks += masks
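    # map each mask's string form to its row index in the per-mask accuracy tables below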
    m_indices = dict(
        zip([get_str_from_mask(mask) for mask in all_masks],
            [i for i in range(len(all_masks))]))

    logging('num mods : {}'.format(NUM_MODS), path=opt.path)
    logging('num ctxs : {}'.format(NUM_CONTEXTS), path=opt.path)
    logging('num tgt  : {}'.format(NUM_TARGET), path=opt.path)
    logging('mask step: {}'.format(MASK_STEP), path=opt.path)
    logging('masks    : {}'.format(m_indices), path=opt.path)

    total_avg_batch_sizes_per_nmod_nctx = [
        [0 for i in range(len(NUM_CONTEXTS))] for j in range(num_modalities)
    ]
    total_avg_acc1_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                    for j in range(num_modalities)]
    total_avg_acc5_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                    for j in range(num_modalities)]
    total_batch_sizes_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                       for j in range(len(all_masks))]
    total_acc1_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                for j in range(len(all_masks))]
    total_acc5_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                for j in range(len(all_masks))]

    with torch.no_grad():
        for batch_idx, (eval_info, eval_context,
                        eval_target) in enumerate(eval_loader):
            # init batch
            eval_context = batch_to_device(eval_context, device)
            eval_target = batch_to_device(eval_target, device)
            eval_all = merge_two_batch(eval_context, eval_target)
            num_episodes = len(eval_context)

            # select target
            _new_eval_target = []
            for target in eval_all:
                _target = tuple([
                    target[i][-NUM_TARGET:] if target[i] is not None else None
                    for i in range(len(target))
                ])
                _new_eval_target += [_target]
            eval_target = _new_eval_target

            # forward
            for n_mods in NUM_MODS:
                masks = get_masks(num_modalities,
                                  min_modes=n_mods,
                                  max_modes=n_mods)
                masks = list(set(masks[::MASK_STEP] + [masks[-1]]))
                for mask in masks:
                    n_mods = sum(mask)
                    avg_m_idx = n_mods - 1
                    m_idx = m_indices[get_str_from_mask(mask)]
                    for c_idx, num_context in enumerate(NUM_CONTEXTS):
                        # select context
                        _new_eval_context = []
                        for context in eval_all:
                            _context = tuple([
                                context[i][:num_context]
                                if context[i] is not None and num_context > 0
                                and mask[i // 2] else None
                                for i in range(len(context))
                            ])
                            _new_eval_context += [_context]
                        eval_context = _new_eval_context

                        # get labels
                        eval_label = torch.Tensor(
                            [i for i in range(num_episodes)]).long().to(device)

                        # infer
                        logprobs_per_batch = []
                        for i_ep in range(num_episodes):
                            new_eval_target = [eval_target[i_ep]
                                               ] * num_episodes

                            # get dim size per episode
                            dim_per_eps = get_dim_size(
                                new_eval_target, is_grayscale=opt.grayscale)

                            # forward
                            logprobs = []
                            for j in range(NUM_Z_SAMPLES):
                                # forward
                                _, _, logprob, info = model.predict(
                                    eval_context,
                                    new_eval_target,
                                    is_grayscale=opt.grayscale,
                                    use_uint8=False)

                                # append to loss_logprobs
                                logprobs += [logprob.unsqueeze(1)]

                            # concat
                            logprobs = torch.cat(logprobs, dim=1)

                            # get logprob
                            logprobs = logprob_logsumexp(logprobs).detach()

                            # get logprob per dimension
                            for i in range(num_episodes):
                                logprobs[i:i + 1] /= float(dim_per_eps[i])

                            # append
                            logprobs_per_batch += [logprobs.unsqueeze(1)]

                        # concat
                        logprobs_per_batch = torch.cat(logprobs_per_batch,
                                                       dim=1)

                        # get acc
                        acc1, acc5 = accuracy(logprobs_per_batch,
                                              eval_label,
                                              topk=(1, 5))
                        cur_acc1 = acc1[0].item()
                        cur_acc5 = acc5[0].item()
                        total_avg_acc1_per_nmod_nctx[avg_m_idx][
                            c_idx] += cur_acc1 * num_episodes
                        total_avg_acc5_per_nmod_nctx[avg_m_idx][
                            c_idx] += cur_acc5 * num_episodes
                        total_avg_batch_sizes_per_nmod_nctx[avg_m_idx][
                            c_idx] += num_episodes
                        total_acc1_per_nmod_nctx[m_idx][
                            c_idx] += cur_acc1 * num_episodes
                        total_acc5_per_nmod_nctx[m_idx][
                            c_idx] += cur_acc5 * num_episodes
                        total_batch_sizes_per_nmod_nctx[m_idx][
                            c_idx] += num_episodes

            # plot
            if (batch_idx + 1) % opt.vis_interval == 0 or (
                    batch_idx + 1) == len(eval_loader):
                elapsed = time.time() - start_time
                start_time = time.time()

                # print
                logging('| {} '
                        '| {:5d}/{:5d} '
                        '| sec/step {:5.2f} '
                        '| acc (top1) {:.3f} '
                        '| acc (top5) {:.3f} '.format(
                            name,
                            batch_idx + 1,
                            len(eval_loader),
                            elapsed / opt.vis_interval,
                            cur_acc1,
                            cur_acc5,
                        ),
                        path=opt.path)

            if (batch_idx + 1) == NUM_ITERS:
                break

    # print
    logging(''.join(
        ['masks V / # of context > '] +
        ['  {:4d}'.format(num_context) for num_context in NUM_CONTEXTS]),
            path=opt.new_path)
    logging('=' * 17 + ' acc1 ' + '=' * 17 + ' | ' + '=' * 17 + ' acc5 ' +
            '=' * 17,
            path=opt.new_path)
    for mask in all_masks:
        mask_str = get_str_from_mask(mask)
        m_idx = m_indices[mask_str]
        txt = ' {} |'.format(mask_str)
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_batch_size = total_batch_sizes_per_nmod_nctx[m_idx][c_idx]
            total_acc1 = total_acc1_per_nmod_nctx[m_idx][
                c_idx] / total_batch_size
            writer.add_scalar('mask{}/{}/acc1'.format(mask_str, name),
                              total_acc1, num_context)
            txt += '  {:3.1f}'.format(total_acc1)
        txt += ' | '
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_batch_size = total_batch_sizes_per_nmod_nctx[m_idx][c_idx]
            total_acc5 = total_acc5_per_nmod_nctx[m_idx][
                c_idx] / total_batch_size
            writer.add_scalar('mask{}/{}/acc5'.format(mask_str, name),
                              total_acc5, num_context)
            txt += '  {:3.1f}'.format(total_acc5)
        logging(txt, path=opt.new_path)

    # print
    logging('', path=opt.new_path)
    logging('', path=opt.new_path)
    logging(''.join(
        ['# of mods V / # of context > '] +
        ['  {:4d}'.format(num_context) for num_context in NUM_CONTEXTS]),
            path=opt.new_path)
    logging('=' * 17 + ' acc1 ' + '=' * 17 + ' | ' + '=' * 17 + ' acc5 ' +
            '=' * 17,
            path=opt.new_path)
    for n_mods in NUM_MODS:
        avg_m_idx = n_mods - 1
        txt = ' {} |'.format(n_mods)
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_avg_batch_size = total_avg_batch_sizes_per_nmod_nctx[
                avg_m_idx][c_idx]
            total_avg_acc1 = total_avg_acc1_per_nmod_nctx[avg_m_idx][
                c_idx] / total_avg_batch_size
            writer.add_scalar('M{}/{}/acc1'.format(n_mods, name),
                              total_avg_acc1, num_context)
            writer.add_scalar('C{}/{}/acc1'.format(num_context, name),
                              total_avg_acc1, n_mods)
            txt += '  {:3.1f}'.format(total_avg_acc1)
        txt += ' | '
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_avg_batch_size = total_avg_batch_sizes_per_nmod_nctx[
                avg_m_idx][c_idx]
            total_avg_acc5 = total_avg_acc5_per_nmod_nctx[avg_m_idx][
                c_idx] / total_avg_batch_size
            writer.add_scalar('M{}/{}/acc5'.format(n_mods, name),
                              total_avg_acc5, num_context)
            writer.add_scalar('C{}/{}/acc5'.format(num_context, name),
                              total_avg_acc5, n_mods)
            txt += '  {:3.1f}'.format(total_avg_acc5)
        logging(txt, path=opt.new_path)

    return total_acc1 / total_batch_size, total_acc5 / total_batch_size
Example #10
def evaluate(eval_loader, test=False):
    # Turn on evaluation mode which disables dropout.
    name = 'test' if test else 'val'
    model.eval()
    transform = get_transform()
    NUM_ITERS = min(opt.num_iters, len(eval_loader))
    NUM_TEST = 5
    NUM_CONTEXTS = sorted([nc for nc in opt.n_context])
    if len(NUM_CONTEXTS) == 0:
        NUM_CONTEXTS = [0, 1, 5, 10]  #[0, 1, 5, 10, 15]
    assert (NUM_TEST + NUM_CONTEXTS[-1]) <= dataset_info['nviews']
    NUM_MODS = sorted([n_mods for n_mods in opt.n_mods])
    if len(NUM_MODS) == 0:
        NUM_MODS = [n_mod for n_mod in range(1, num_modalities + 1)]
    assert NUM_MODS[-1] <= num_modalities
    assert NUM_MODS[0] > 0
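    # enumerate every modality mask for each requested number of active modalities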
    all_masks = []
    for n_mods in NUM_MODS:
        masks = get_masks(num_modalities, min_modes=n_mods, max_modes=n_mods)
        all_masks += masks

    logging('num mods : {}'.format(NUM_MODS), path=opt.path)
    logging('num ctxs : {}'.format(NUM_CONTEXTS), path=opt.path)
    logging('num tgt  : {}'.format(NUM_TEST), path=opt.path)
    logging('masks    : {}'.format(all_masks), path=opt.path)

    hpt_tgt_gen = {}
    avg_diffs = {}
    num_datas = {}
    with torch.no_grad():
        for i_sample in range(1, opt.num_samples + 1):
            did_plot = [False] * num_classes
            for batch_idx, (eval_info, eval_context,
                            eval_target) in enumerate(eval_loader):
                # init batch
                eval_context = batch_to_device(eval_context, device)
                eval_target = batch_to_device(eval_target, device)
                eval_all = merge_two_batch(eval_context, eval_target)
                num_episodes = len(eval_context)

                # get img_queries
                img_queries = torch.from_numpy(
                    np.array(eval_info[0]['add_cameras'])).float()

                # get true_images and hand_images
                true_images = load_images(eval_info[0]['add_images'],
                                          transform)
                _true_images = get_grid_image(true_images,
                                              16,
                                              3,
                                              64,
                                              64,
                                              nrow=4,
                                              pad_value=0)
                hand_images = load_images(eval_info[0]['hand_images'],
                                          transform)
                _hand_images = get_grid_image(hand_images,
                                              15,
                                              3,
                                              64,
                                              64,
                                              nrow=15,
                                              pad_value=0)
                _data_images = []
                for idx, (nchannels, nheight, nwidth, _,
                          mtype) in enumerate(dataset_info['dims']):
                    if mtype == 'image':
                        _data_images += [eval_all[0][idx * 2]]
                _data_images = get_combined_visualization_image_data(
                    opt.dataset,
                    dataset_info['dims'],
                    _data_images,
                    dataset_info['nviews'],
                    min(4, num_episodes),
                    nrow=15,
                    pad_value=0)[0]
                ''' temporary '''
                assert len(eval_context) == 1
                assert len(eval_target) == 1
                cls = eval_info[0]['class']
                ''' per class '''
                # visualize per class
                if not did_plot[cls]:
                    # change flag
                    did_plot[cls] = True

                    # draw true_images and hand_images
                    writer.add_image(
                        '{}/gt-img-cls{}-i{}'.format(name, cls, i_sample),
                        _true_images, 0)
                    writer.add_image(
                        '{}/hand-img-cls{}-i{}'.format(name, cls, i_sample),
                        _hand_images, 0)
                    writer.add_image(
                        '{}/data-img-cls{}-i{}'.format(name, cls, i_sample),
                        _data_images, 0)
                    for num_context in NUM_CONTEXTS:
                        _hand_images = get_grid_image(
                            hand_images[:num_context],
                            num_context,
                            3,
                            64,
                            64,
                            nrow=5,
                            pad_value=0)
                        writer.add_image(
                            '{}/ctx-hand-img-cls{}-i{}-nc{}'.format(
                                name, cls, i_sample, num_context),
                            _hand_images, 0)

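                    # render reconstructions and generations for one (mask, num_context) setting and log them to the writer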
                    def draw_img_gen(mask, num_context=0):
                        # get mask index
                        m_idx = sum(mask) - 1

                        # get context
                        new_eval_context, new_eval_target = trim_context_target(
                            eval_all,
                            num_context=num_context,
                            mask=mask,
                            num_modalities=num_modalities)
                        if sum([
                                int(new_eval_target[0][j * 2] is None)
                                for j in range(num_modalities)
                        ]) == num_modalities:
                            return

                        # select test
                        _new_eval_target = []
                        for i in range(num_episodes):
                            _target = []
                            for idx, (nchannels, nheight, nwidth, _,
                                      mtype) in enumerate(
                                          dataset_info['dims']):
                                data, query = new_eval_target[i][
                                    idx * 2], new_eval_target[i][idx * 2 + 1]
                                if mtype == 'haptic':
                                    _target += [
                                        data[-NUM_TEST:]
                                        if data is not None else None
                                    ]
                                    _target += [
                                        query[-NUM_TEST:]
                                        if data is not None else None
                                    ]
                                else:
                                    _target += [data]
                                    _target += [query]
                            _new_eval_target += [tuple(_target)]
                        new_eval_target = _new_eval_target

                        # get batch size
                        batch_size, mod_batch_sizes = get_batch_size(
                            new_eval_target)

                        # get queries
                        mod_queries, num_mod_queries = get_queries(
                            new_eval_target,
                            device,
                            num_hpt_queries=NUM_TEST,
                            img_queries=img_queries)

                        # forward
                        outputs, _, _, _ = model(new_eval_context,
                                                 new_eval_target,
                                                 is_grayscale=opt.grayscale)

                        # generate
                        gens, _ = model.generate(new_eval_context,
                                                 tuple(mod_queries),
                                                 is_grayscale=opt.grayscale)

                        # visualize
                        img_ctxs, img_tgts, img_outputs, img_gens = [], [], [], []
                        hpt_ctxs, hpt_tgts, hpt_outputs, hpt_gens = [], [], [], []
                        for idx, (nchannels, nheight, nwidth, _,
                                  mtype) in enumerate(dataset_info['dims']):
                            # get output and gen
                            output = outputs[idx]
                            gen = gens[idx]
                            _num_mod_queries = num_mod_queries[idx]

                            # visualize
                            if mtype == 'image':
                                # grayscale
                                if opt.grayscale:
                                    if output.size(0) > 0:
                                        output = output.expand(
                                            output.size(0), nchannels, nheight,
                                            nwidth)
                                    gen = gen.expand(gen.size(0), nchannels,
                                                     nheight, nwidth)

                                # get ctx, tgt
                                if num_context > 0 and mask[idx]:
                                    sz = new_eval_context[0][idx *
                                                             2].size()[1:]
                                    ctx = torch.cat([
                                        new_eval_context[0][idx * 2],
                                        gen.new_zeros(
                                            dataset_info['nviews'] -
                                            num_context, *sz)
                                    ],
                                                    dim=0)
                                    num_target = new_eval_target[0][
                                        idx * 2].size(0) if new_eval_target[0][
                                            idx * 2] is not None else 0
                                    assert num_target == output.size(0)
                                    if num_target > 0:
                                        tgt = torch.cat([
                                            gen.new_zeros(
                                                dataset_info['nviews'] -
                                                num_target, *sz),
                                            new_eval_target[0][idx * 2],
                                        ],
                                                        dim=0)
                                        output = torch.cat([
                                            gen.new_zeros(
                                                dataset_info['nviews'] -
                                                num_target, *sz),
                                            output,
                                        ],
                                                           dim=0)
                                    else:
                                        tgt = gen.new_zeros(
                                            dataset_info['nviews'] *
                                            num_episodes, *sz)
                                        output = gen.new_zeros(
                                            dataset_info['nviews'] *
                                            num_episodes, *sz)
                                else:
                                    ctx = gen.new_zeros(
                                        dataset_info['nviews'] * num_episodes,
                                        nchannels, nheight, nwidth)
                                    tgt = new_eval_target[0][idx * 2]

                                # append to list
                                img_gens += [gen]
                                img_outputs += [output]
                                img_ctxs += [ctx]
                                img_tgts += [tgt]
                                num_img_queries = _num_mod_queries
                            elif mtype == 'haptic':
                                ctx = new_eval_context[0][idx * 2]
                                tgt = new_eval_target[0][idx * 2]

                                # append to list
                                hpt_gens += [gen]
                                hpt_outputs += [output]
                                hpt_ctxs += [ctx]
                                hpt_tgts += [tgt]
                                num_hpt_queries = _num_mod_queries
                            else:
                                raise NotImplementedError

                        # combine haptic
                        if not get_str_from_mask(mask) in hpt_tgt_gen:
                            hpt_tgt_gen[get_str_from_mask(mask)] = {}
                            avg_diffs[get_str_from_mask(mask)] = np.zeros(
                                len(NUM_CONTEXTS))
                            num_datas[get_str_from_mask(mask)] = 0
                        hpt_tgts = torch.cat(hpt_tgts, dim=1)
                        hpt_gens = torch.cat(hpt_gens, dim=1)
                        hpt_tgt_gen[get_str_from_mask(mask)][num_context] = (
                            hpt_tgts, hpt_gens)

                        # visualize combined image
                        xgs = get_combined_visualization_image_data(
                            opt.dataset,
                            dataset_info['dims'],
                            img_gens,
                            num_img_queries,
                            min(4, num_episodes),
                            nrow=4,
                            pad_value=0)
                        xos = get_combined_visualization_image_data(
                            opt.dataset,
                            dataset_info['dims'],
                            img_outputs,
                            dataset_info['nviews'],
                            min(4, num_episodes),
                            nrow=4,
                            pad_value=0)
                        xcs = get_combined_visualization_image_data(
                            opt.dataset,
                            dataset_info['dims'],
                            img_ctxs,
                            dataset_info['nviews'],
                            min(4, num_episodes),
                            nrow=4,
                            pad_value=0)
                        _xcs = get_combined_visualization_image_data(
                            opt.dataset,
                            dataset_info['dims'],
                            img_ctxs,
                            num_context,
                            min(4, num_episodes),
                            nrow=5,
                            pad_value=0)
                        xts = get_combined_visualization_image_data(
                            opt.dataset,
                            dataset_info['dims'],
                            img_tgts,
                            dataset_info['nviews'],
                            min(4, num_episodes),
                            nrow=4,
                            pad_value=0)
                        for i, (xc, xt, xo,
                                xg) in enumerate(zip(xcs, xts, xos, xgs)):
                            writer.add_image(
                                '{}/ctx-cls{}-M{}-mask{}-i{}-nc{}/img'.format(
                                    name, cls, m_idx + 1,
                                    get_str_from_mask(mask), i, num_context),
                                _xcs[i], 0)
                            writer.add_image(
                                '{}/gen-cls{}-M{}-mask{}-i{}-nc{}/img'.format(
                                    name, cls, m_idx + 1,
                                    get_str_from_mask(mask), i, num_context),
                                xg, 0)
                            x = torch.cat([xc, xt, xo, xg], dim=2)
                            writer.add_image(
                                '{}/ctx-tgt-rec-gen-cls{}-M{}-mask{}-i{}-nc{}/img'
                                .format(name, cls, m_idx + 1,
                                        get_str_from_mask(mask), i,
                                        num_context), x, 0)

                    # run vis
                    for mask in all_masks:
                        for num_context in NUM_CONTEXTS:
                            draw_img_gen(mask=mask, num_context=num_context)

                        # visualize combined haptic
                        m_idx = sum(mask) - 1  # get mask index
                        xs, diffs = get_combined_visualization_haptic_data(
                            hpt_tgt_gen[get_str_from_mask(mask)],
                            title='mask: {}'.format(get_str_from_mask(mask)))
                        _diffs = [
                            np.mean([diffs[i][j] for i in range(len(diffs))])
                            for j in range(len(NUM_CONTEXTS))
                        ]
                        num_datas[get_str_from_mask(mask)] += 1
                        for j, diff in enumerate(_diffs):
                            avg_diffs[get_str_from_mask(mask)][j:j + 1] += diff
                            num_context = NUM_CONTEXTS[j]
                            writer.add_scalar(
                                '{}/diff-cls{}-M{}-mask{}-all/hpt'.format(
                                    name, cls, m_idx + 1,
                                    get_str_from_mask(mask)), diff,
                                num_context)

                        for i, x in enumerate(xs):
                            writer.add_image(
                                '{}/tgt-gen-cls{}-M{}-mask{}-i{}/hpt'.format(
                                    name, cls, m_idx + 1,
                                    get_str_from_mask(mask), i),
                                convert_npimage_torchimage(x), 0)

                            for j, diff in enumerate(diffs[i]):
                                num_context = NUM_CONTEXTS[j]
                                writer.add_scalar(
                                    '{}/diff-cls{}-M{}-mask{}-i{}/hpt'.format(
                                        name, cls, m_idx + 1,
                                        get_str_from_mask(mask), i), diff,
                                    num_context)

                if (batch_idx + 1) % 1 == 0:
                    print(batch_idx + 1, '/', NUM_ITERS, ' [',
                          len(eval_loader), ']')
                if (batch_idx + 1) == NUM_ITERS:
                    break

    return
Example #11
    args = parser.parse_args()

    if args.dataset_folder:
        X = pd.read_csv(args.dataset_folder + '/X.csv')
        y = pd.read_csv(args.dataset_folder + '/y.csv')
        networkx_graph = nx.read_graphml(args.dataset_folder + '/graph.graphml')
    else:
        X = pd.read_csv(args.X)
        y = pd.read_csv(args.y)
        networkx_graph = nx.read_graphml(args.graph)

    networkx_graph = nx.relabel_nodes(networkx_graph, {str(i): i for i in range(len(networkx_graph))})
    cat_features = args.cat_features

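    # split the nodes into train / validation / train+val / test masks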
    train_mask, val_mask, train_val_mask, test_mask = get_masks(X.shape[0],
                                                                train_size=args.train_size,
                                                                val_size=args.val_size,
                                                                random_seed=args.random_seed, )

    os.makedirs('losses/', exist_ok=True)

    if args.model.lower() == 'gbdt':
        from models.GBDT import GBDT
        model = GBDT(depth=args.depth)
        model.fit(X, y, train_mask, val_mask, test_mask,
                  cat_features=cat_features, num_epochs=args.num_epochs, patience=args.patience,
                  learning_rate=args.learning_rate, plot=False, verbose=False,
                  loss_fn=args.loss_fn)

    elif args.model.lower() == 'mlp':
        from models.MLP import MLP
        model = MLP(task=args.task)
Example #12
    def learn(self, batch_size=64, agent=0):
        # DDPG implementation
        if agent == 0:
            experiences = self.memory_0.sample(batch_size)
        if agent == 1:
            experiences = self.memory_1.sample(batch_size)
        states, actions, rewards, next_states, dones, batch_lengths = experiences
        # Batches contain variable-length episodes, so masks zero out the unused (padded) steps.
        masks = get_masks(batch_lengths)

        # update critic
        with torch.no_grad():
            if agent == 0:
                next_actions = torch.cat([
                    self.actor_target_0(next_states[:, :, :24]), actions[:, :,
                                                                         2:]
                ],
                                         dim=2)
            if agent == 1:
                next_actions = torch.cat([
                    actions[:, :, :2],
                    self.actor_target_1(next_states[:, :, 24:])
                ],
                                         dim=2)
            q_value_prime = self.critic_target(next_states, next_actions,
                                               agent).squeeze(2) * masks
        q_value = self.critic_local(states.requires_grad_(True), actions,
                                    agent).squeeze(2) * masks
        td_error = (rewards + self.gamma * q_value_prime *
                    (1 - dones) - q_value) * masks  # one-step estimate

        critic_loss = ((td_error**2).sum(1) / batch_lengths).mean()
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()
        self.soft_update(self.critic_local, self.critic_target, tau=0.01)

        # update actor
        if agent == 0:
            actions_pred = torch.cat(
                [self.actor_local_0(states[:, :, :24]), actions[:, :, 2:]],
                dim=2)
        if agent == 1:
            actions_pred = torch.cat(
                [actions[:, :, :2],
                 self.actor_local_1(states[:, :, 24:])],
                dim=2)

        q_value = self.critic_local(states, actions_pred, agent)
        actor_loss = -(
            (q_value.squeeze() * masks).sum(1) / batch_lengths).mean()

        if agent == 0:
            self.actor_optimizer_0.zero_grad()
            (actor_loss).backward()
            torch.nn.utils.clip_grad_norm_(self.actor_local_0.parameters(), 1)
            self.actor_optimizer_0.step()
            self.soft_update(self.actor_local_0, self.actor_target_0, tau=0.01)
        if agent == 1:
            self.actor_optimizer_1.zero_grad()
            (actor_loss).backward()
            torch.nn.utils.clip_grad_norm_(self.actor_local_1.parameters(), 1)
            self.actor_optimizer_1.step()
            self.soft_update(self.actor_local_1, self.actor_target_1, tau=0.01)

        self.noise.reset()

        if self.t_step % 128 == 0:
            print(f'q_value: {q_value.detach().mean().item():.5f}, '
                  f'critic_loss: {critic_loss.detach().mean().item():.5f}, '
                  f'actor_loss: {actor_loss.detach().mean().item():.5f} ')