def test_safe_to_add_to_group_wrong_extension(self):
        """A workspace carrying a different extension must not be accepted."""
        instrument, extension = "MUSR", "MA"

        # Name ends in "FD" while the requested extension is "MA".
        wrong_ext_ws = create_workspace("".join([instrument, "test", "FD"]))
        make_group([create_workspace("dummy")], "group")
        existing_groups = [retrieve_ws("group")]

        verdict = safe_to_add_to_group(wrong_ext_ws, instrument,
                                       existing_groups, extension)
        self.assertEqual(verdict, False)
    def test_check_not_in_group(self):
        """check_not_in_group is False for a group member, True otherwise."""
        inside = create_workspace("in")
        outside = create_workspace("out")
        make_group([inside], "group")
        group = retrieve_ws("group")

        # The group should contain exactly the one workspace we put in it.
        self.assertEqual(len(group.getNames()), 1)
        self.assertEqual(group.getNames()[0], "in")

        # Member -> not "not in group" (False); outsider -> True.
        for workspace, expected in ((inside, False), (outside, True)):
            self.assertEqual(check_not_in_group([group], workspace.name()),
                             expected)
    def test_add_list_to_group(self):
        """add_list_to_group appends the named workspaces to an existing group."""
        ws = create_workspace("unit")
        ws2 = create_workspace("test")
        tmp = create_workspace("dummy")
        make_group([tmp], "group")
        group = retrieve_ws("group")

        # Starting state: the group holds only the seed workspace.
        self.assertEqual(len(group.getNames()), 1)
        self.assertEqual(group.getNames()[0], "dummy")

        add_list_to_group([ws.name(), ws2.name()], group)

        # Membership order is not significant; compare as sorted lists so a
        # failure shows both full lists rather than "False is not true".
        self.assertEqual(sorted(group.getNames()),
                         sorted(["dummy", "unit", "test"]))
    def test_add_to_group_ignore_if_already_in_group(self):
        """add_to_group must skip workspaces that already belong to a group.

        Only workspaces matching the requested instrument and extension that
        are not already grouped should land in the new per-run group.
        """
        instrument = "MUSR"
        extension = "MA"
        run = "62260"
        ws = create_workspace(instrument + run + "fwd" + extension)
        ws2 = create_workspace(instrument + run + "bwd" + extension)
        # Decoys that must not end up in the MUSR62260 group:
        # wrong instrument, then wrong extension.
        create_workspace("EMU" + run + "fwd" + extension)
        create_workspace(instrument + run + "fwd" + "FD")
        # Regression check: table workspaces previously broke the grouping.
        table_name = create_table_workspace(instrument + run + "table" +
                                            extension)

        # ws2 is already grouped, so add_to_group must leave it alone.
        make_group([ws2], "group")
        add_to_group(instrument, extension)

        group = retrieve_ws(instrument + run)
        # Membership order is not significant; compare as sorted lists.
        self.assertEqual(sorted(group.getNames()),
                         sorted([ws.name(), table_name]))
def add_to_group(instrument, extension):
    """Sort loose workspaces in the ADS into per-run group workspaces.

    Every workspace in the AnalysisDataService that ``safe_to_add_to_group``
    accepts is assigned to the group named ``instrument + run``, where
    ``run`` is obtained from ``get_run_number_from_workspace_name``. If a
    group with that name already exists in the ADS, the workspaces are
    appended to it via ``add_list_to_group``; otherwise a new group is
    created via ``make_group``.

    :param instrument: instrument name prefix (e.g. "MUSR"); also the prefix
        of every group name produced.
    :param extension: workspace-name extension forwarded to
        ``safe_to_add_to_group`` to decide which workspaces are grouped.
    """
    str_names = AnalysisDataService.getObjectNames()
    ws_list = AnalysisDataService.retrieveWorkspaces(str_names)

    # Existing group workspaces. They are passed to safe_to_add_to_group,
    # presumably so it can skip workspaces that are already grouped.
    groups = [ws for ws in ws_list if ws.isGroup()]

    def string_name(ws):
        # Accept either a workspace name or a workspace object.
        if isinstance(ws, str):
            return ws
        return ws.name()

    # Names of the workspaces to be grouped (current instrument/extension).
    names = [string_name(ws) for ws in ws_list
             if safe_to_add_to_group(ws, instrument, groups, extension)]

    # Seed with every existing group so a run that already has a group is
    # appended to it instead of getting a second group of the same name.
    members_by_group = {group.name(): [] for group in groups}
    existing_group_names = set(members_by_group)

    for name in names:
        run = get_run_number_from_workspace_name(name, instrument)
        members_by_group.setdefault(instrument + run, []).append(name)

    # Append to the groups that already exist (every one is a key above).
    for group in groups:
        add_list_to_group(members_by_group[group.name()], group)

    # Create groups for runs that had none. A distinct loop variable is used
    # so the name is never shadowed by a comprehension variable (which would
    # leak and rebind it under Python 2).
    for group_name, members in members_by_group.items():
        if group_name not in existing_group_names:
            make_group(members, group_name)