Example #1
def match_smiles_lists(pred_list, target_list, beam_size, should_print=True):
    n_data = 0
    n_matched = np.zeros(beam_size)  # Count of matched smiles
    n_invalid = np.zeros(beam_size)  # Count of invalid smiles
    n_repeat = np.zeros(beam_size)  # Count of repeated predictions
    #
    # with open('template/rare_indices.txt', 'r+') as r_file:
    #     rare_rxn_list = json.load(r_file)

    for data_idx, target_smiles in enumerate(tqdm.tqdm(target_list)):
        # if data_idx not in rare_rxn_list:
        #     continue
        n_data += 1
        target_set = set(data_utils.canonicalize(smiles_list=target_smiles.split('.')))

        pred_beam = pred_list[data_idx]

        beam_matched = False
        prev_sets = []
        for beam_idx, pred_smiles in enumerate(pred_beam):
            pred_set = set(data_utils.canonicalize(smiles_list=pred_smiles.split('.')))
            if '' in pred_set:
                pred_set.remove('')
            set_matched = match_smiles_set(pred_set, target_set)

            # Check if current pred_set matches any previous sets
            for prev_set in prev_sets:
                if match_smiles_set(pred_set, prev_set):
                    n_repeat[beam_idx] += 1

            if len(pred_set) > 0:
                # Add pred set to list of predictions for current example
                prev_sets.append(pred_set)
            else:
                # If the pred set is empty and the string is not, then invalid
                if pred_smiles != '':
                    n_invalid[beam_idx] += 1

            # Increment if not yet matched beam and the pred set matches
            if set_matched and not beam_matched:
                n_matched[beam_idx] += 1
                beam_matched = True

    if should_print:
        print('total examples: %d' % n_data)
        for beam_idx in range(beam_size):
            match_perc = np.sum(n_matched[:beam_idx+1]) / n_data
            invalid_perc = n_invalid[beam_idx] / n_data
            repeat_perc = n_repeat[beam_idx] / n_data

            print('beam: %d, matched: %.3f, invalid: %.3f, repeat: %.3f' %
                (beam_idx+1, match_perc, invalid_perc, repeat_perc))

    return n_data, n_matched, n_invalid, n_repeat
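
A minimal usage sketch for the function above, assuming it lives in a module that already provides numpy, tqdm, data_utils and match_smiles_set; the file names and the one-line-per-beam layout below are hypothetical.

# Hypothetical driver: 'targets.txt' holds one target SMILES per example and
# 'predictions.txt' holds beam_size consecutive prediction lines per example.
def _read_lines(path):
    with open(path) as f:
        return [line.strip() for line in f]

beam_size = 10
target_list = _read_lines('targets.txt')
flat_preds = _read_lines('predictions.txt')
pred_list = [flat_preds[i:i + beam_size]
             for i in range(0, len(flat_preds), beam_size)]

n_data, n_matched, n_invalid, n_repeat = match_smiles_lists(
    pred_list, target_list, beam_size)
print('top-1 accuracy: %.3f' % (n_matched[0] / n_data))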
Example #2
    def infer(self, lr_data):

        lr_data = data_utils.canonicalize(
            lr_data)  # to torch.FloatTensor  thwc

        print(lr_data.size())
        _, h, w, _ = lr_data.size()

        lr_yuv = data_utils.rgb2yCbCr(lr_data)
        lr_yuv = lr_yuv.permute(0, 3, 1, 2)  # thwc -> tchw

        lr_y = lr_yuv[:, 0:1, :, :]
        lr_u = lr_yuv[:, 1:2, :, :]
        lr_v = lr_yuv[:, 2:3, :, :]

        # dual direct temporal padding
        lr_y_seq, n_pad_front = self.pad_sequence(lr_y)

        # infer
        hr_y_seq = self.net_G.infer_sequence(lr_y_seq, self.device)
        hr_u_seq = tvs.resize(lr_u, [self.scale * h, self.scale * w],
                              interpolation=3)  # bilinear:2(default) bicubic:3
        hr_v_seq = tvs.resize(lr_v, [self.scale * h, self.scale * w],
                              interpolation=3)  # bilinear:2(default) bicubic:3

        hr_yuv = torch.cat((hr_y_seq, hr_u_seq, hr_v_seq), dim=1)
        hr_yuv = hr_yuv.permute(0, 2, 3, 1)  # tchw -> thwc

        hr_rgb = data_utils.yCbCr2rgb(hr_yuv).numpy()
        hr_seq = data_utils.float32_to_uint8(hr_rgb)  # thwc|rgb|uint8

        return hr_seq
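
In this variant only the luma (Y) channel goes through the network, while Cb/Cr are upsampled with bicubic interpolation, a common shortcut in video super-resolution since chroma carries less fine detail. A hedged call-site sketch follows, where `runner` is a placeholder for whatever object owns this infer() together with net_G, scale, device and pad_sequence; the clip is synthetic.

import numpy as np

# Synthetic thwc clip just to illustrate the expected input/output layout.
lr_clip = np.random.randint(0, 256, size=(16, 64, 64, 3), dtype=np.uint8)
hr_clip = runner.infer(lr_clip)                # thwc | rgb | uint8
print(lr_clip.shape, '->', hr_clip.shape)      # spatial dims scaled by runner.scale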
Example #3
    def infer(self, lr_data):

        lr_data = data_utils.canonicalize(lr_data)  # to torch.FloatTensor
        lr_data = lr_data.permute(0, 3, 1, 2)  # tchw

        # dual direct temporal padding
        lr_data, n_pad_front = self.pad_sequence(lr_data)

        # infer
        hr_seq = self.net_G.infer_sequence(lr_data, self.device)

        return hr_seq
Example #4
    def infer(self, lr_data):
        """ Function of inference

            Parameters:
                :param lr_data: a rgb video sequence with shape thwc
                :return: a rgb video sequence with type np.uint8 and shape thwc
        """

        # canonicalize
        lr_data = data_utils.canonicalize(lr_data)  # to torch.FloatTensor
        lr_data = lr_data.permute(0, 3, 1, 2)  # tchw

        # temporal padding
        lr_data, n_pad_front = self.pad_sequence(lr_data)

        # infer
        hr_seq = self.net_G.infer_sequence(lr_data, self.device)
        hr_seq = hr_seq[n_pad_front:, ...]
        return hr_seq
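
The slice hr_seq[n_pad_front:, ...] drops the frames that pad_sequence prepended, so the output length matches the input. Below is a purely illustrative sketch of what such a padding helper could look like; the real self.pad_sequence in this codebase may pad differently.

import torch

def pad_sequence_example(seq, n_pad_front=2, n_pad_back=2):
    # seq: tchw tensor; replicate the edge frames at both ends and report how
    # many frames were added in front so the caller can trim them afterwards.
    front = seq[:1].repeat(n_pad_front, 1, 1, 1)
    back = seq[-1:].repeat(n_pad_back, 1, 1, 1)
    return torch.cat([front, seq, back], dim=0), n_pad_front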
Example #5
def match_smiles_lists(pred_list,
                       target_list,
                       beam_size,
                       args,
                       should_print=True):
    n_data = 0
    n_matched = np.zeros(beam_size)  # Count of matched smiles
    n_invalid = np.zeros(beam_size)  # Count of invalid smiles
    n_repeat = np.zeros(beam_size)  # Count of repeated predictions

    result_path = args.result_path
    seed = args.seed
    #
    # with open('template/rare_indices.txt', 'r+') as r_file:
    #     rare_rxn_list = json.load(r_file)

    for data_idx, target_smiles in enumerate(tqdm.tqdm(target_list)):
        # if data_idx not in rare_rxn_list:
        #     continue
        n_data += 1
        target_set = set(
            data_utils.canonicalize(smiles_list=target_smiles.split('.')))

        pred_beam = pred_list[data_idx]

        beam_matched = False
        prev_sets = []
        num_repeat = 0
        num_invalid = 0
        for beam_idx, pred_smiles in enumerate(pred_beam):
            cnt_flag = False
            pred_set = set(
                data_utils.canonicalize(smiles_list=pred_smiles.split('.')))
            if '' in pred_set:
                pred_set.remove('')
            set_matched = match_smiles_set(pred_set, target_set)

            # Check if current pred_set matches any previous sets
            for cnt, prev_set in enumerate(prev_sets):
                if match_smiles_set(pred_set, prev_set):
                    n_repeat[beam_idx] += 1
                    if not cnt_flag:
                        num_repeat += 1
                        cnt_flag = True

            if len(pred_set) > 0:
                # Add pred set to list of predictions for current example
                prev_sets.append(pred_set)
            else:
                # If the pred set is empty and the string is not, then invalid
                if pred_smiles != '':
                    n_invalid[beam_idx] += 1
                    num_invalid += 1

            # Increment if not yet matched beam and the pred set matches.
            # The index is shifted down by the number of invalid/repeated
            # predictions seen so far, so the rank counts only unique, valid beams.
            if set_matched and not beam_matched:
                n_matched[beam_idx - num_invalid - num_repeat] += 1
                beam_matched = True

    if should_print:
        print('total examples: %d' % n_data)
        result_path_prefix = 'experiments/results'
        if not os.path.isdir(result_path_prefix):
            os.mkdir(result_path_prefix)
        if not os.path.isdir(result_path):
            os.mkdir(result_path)

        f = open(result_path + str(seed) + '.csv',
                 'w',
                 encoding='utf-8',
                 newline='')
        wr = csv.writer(f)

        for beam_idx in range(beam_size):
            match_perc = np.sum(n_matched[:beam_idx + 1]) / n_data
            invalid_perc = n_invalid[beam_idx] / n_data
            repeat_perc = n_repeat[beam_idx] / n_data
            wr.writerow([match_perc, invalid_perc, repeat_perc])

            print('beam: %d, matched: %.3f, invalid: %.3f, repeat: %.3f' %
                  (beam_idx + 1, match_perc, invalid_perc, repeat_perc))
        f.close()

    return n_data, n_matched, n_invalid, n_repeat
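
This variant only reads result_path and seed from args, so for quick experiments a bare Namespace works as well as a full CLI parser; the values below are placeholders, and pred_list/target_list are prepared as in Example #1.

import argparse

# The output file is opened as result_path + str(seed) + '.csv', i.e.
# result_path is used as a filename prefix (and also passed to os.mkdir),
# so give it a trailing '/' when it is meant to name a directory.
args = argparse.Namespace(result_path='experiments/results/my_run/', seed=0)

n_data, n_matched, n_invalid, n_repeat = match_smiles_lists(
    pred_list, target_list, beam_size=10, args=args)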
Example #6
def combine_latent(input_dir,
                   n_latent,
                   beam_size,
                   output_path=None,
                   clean=False,
                   cross=False,
                   cross_ensem=False,
                   alternative=False):
    """
    Reads the output smiles from each of the latent classes and combines them.
    Args:
        input_dir: The path to the input directory containing output files
        n_latent: The number of latent classes used for the model
        beam_size: Number of smiles results per reaction
        output_path: If given, writes the combined smiles to this path
        clean: If True, canonicalize each predicted smiles and skip invalid ones
    """
    # results_path is the prefix for the different latent file outputs
    latent_list = []

    def parse(line):
        c_line = line.strip().replace(' ', '')
        smiles, score = c_line.split(',')
        score = float(score)
        return (smiles, score)

    for latent_idx in range(n_latent):
        file_path = '%s/output_%d' % (input_dir, latent_idx)
        smiles_list = data_utils.read_file(file_path,
                                           beam_size=beam_size,
                                           parse_func=parse)

        latent_list.append(smiles_list)

    combined_list = []

    if output_path is not None:
        output_file = open(output_path, 'w+')

    n_data = len(latent_list[0])
    for data_idx in tqdm.tqdm(range(n_data)):
        r_dict = {}
        for latent_idx in range(n_latent):
            output_list = latent_list[latent_idx][data_idx]
            for smiles, score in output_list:

                if clean:
                    smiles = data_utils.canonicalize(smiles)
                    if smiles == '':
                        continue

                if smiles not in r_dict:  # Add the output to dictionary
                    r_dict[smiles] = (score, latent_idx)
                elif score > r_dict[smiles][0]:
                    # Update with the best score if applicable
                    r_dict[smiles] = (score, latent_idx)
        sorted_output = sorted(r_dict.items(),
                               key=operator.itemgetter(1),
                               reverse=True)
        top_smiles = []
        for beam_idx in range(beam_size):
            if beam_idx < len(sorted_output):
                smiles, (score, latent_idx) = sorted_output[beam_idx]
                top_smiles.append(smiles)

                if output_path is not None:
                    output_file.write('%s,%.4f,%d\n' %
                                      (smiles, score, latent_idx))

        combined_list.append(top_smiles)
    if output_path is not None:
        output_file.close()
    return combined_list
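
A hedged usage sketch: the input directory is expected to contain files output_0 … output_{n_latent-1}, each holding beam_size lines per reaction in the "SMILES, score" form that parse() reads; the concrete paths below are placeholders.

combined_smiles = combine_latent(
    input_dir='experiments/latent_outputs',        # holds output_0 ... output_3
    n_latent=4,
    beam_size=10,
    output_path='experiments/latent_outputs/combined.csv',
    clean=True)                                    # canonicalize, drop invalid SMILES

# combined_smiles[i] is the merged, score-sorted candidate list (at most
# beam_size entries) for reaction i.
print(len(combined_smiles), 'reactions combined')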