コード例 #1
0
ファイル: simulation.py プロジェクト: leeseungho90/liam2
def show_top_times(what, times):
    """Print every (name, duration) pair of *times* followed by their sum.

    *what* is the label used in the headings (e.g. ``"processes"``); all
    durations are formatted with ``time2str``.
    """
    num_entries = len(times)
    print("top %d %s:" % (num_entries, what))
    durations = []
    for entry_name, duration in times:
        durations.append(duration)
        print(" - %s: %s" % (entry_name, time2str(duration)))
    print("total for top %d %s:" % (num_entries, what), end=' ')
    print(time2str(sum(durations)))
コード例 #2
0
ファイル: team.py プロジェクト: marwahaha/code
 def publishEvent(self, event):
     """Publish *event* (when publishing is enabled) and echo it to stdout.

     The event goes out through ``self.pub_team_event`` only if
     ``self.publish`` is truthy; the stdout line is always written, in the
     form ``<time> - r<robotId> - <eventString>``.
     """
     trace('publishEvent %s', event.eventString)
     if self.publish:
         self.pub_team_event.publish(event)
     # Also print to stdout so the events can be watched on the coach, as a
     # kind of error log.  A single-argument print() call prints the same
     # text under Python 2 (parenthesized expression) and Python 3.
     print(time2str(event.timeStamp) + " - r" + str(event.robotId) + " - " + event.eventString)
     # Workaround for super-fast consecutive publishing: a short pause keeps
     # ROS from dropping messages.
     time.sleep(0.01)
コード例 #3
0
def show_top_processes(process_time, num_processes):
    """Print the *num_processes* entries of *process_time* with the largest
    times, followed by the sum of those times.

    :param process_time: dict mapping process name -> time value accepted by
        ``time2str``.
    :param num_processes: maximum number of processes to list.
    """
    # items() (instead of the Python-2-only iteritems()) works on both
    # Python 2 and 3; sorted() builds a new list either way.
    process_times = sorted(process_time.items(),
                           key=operator.itemgetter(1),
                           reverse=True)
    top_processes = process_times[:num_processes]
    # Only single-argument print() calls are used, so the output is the same
    # under Python 2 and 3.  The last line is printed in one call; it emits
    # exactly what the former "print a," / "print b" soft-space pair did.
    print("top %d processes:" % num_processes)
    for name, p_time in top_processes:
        print(" - %s: %s" % (name, time2str(p_time)))
    total_time = sum(p_time for _, p_time in top_processes)
    print("total for top %d processes: %s"
          % (num_processes, time2str(total_time)))
コード例 #4
0
ファイル: render.py プロジェクト: hughplay/TVR
    def summary(self):
        """Write an end-of-run report through ``self.logger.info``.

        The report lists the data (and, if enabled, log) directory, progress
        and success counters with their rates, timing figures, and one
        sorted ``- key: value`` table per statistic kept by
        ``self.sample_manager``; sections are separated by ``---`` lines.
        """
        log = self.logger.info

        def log_counts(title, counts):
            # One report section: separator, title, then the entries of the
            # *counts* mapping in sorted order.
            log('---')
            log(title)
            for entry_key, amount in sorted(counts.items()):
                log('- {}: {}'.format(entry_key, amount))

        log('---')
        log('Data dir: {}'.format(self.output_dir.resolve()))
        if self.config.output_log:
            log('Log dir: {}'.format(self.log_dir.resolve()))

        log('---')
        stats = self.sample_manager.state()
        log('Progress: {} (load: {})'.format(
            stats['sample_success'] + stats['sample_load'],
            stats['sample_load']))
        log('Sample success: {} ({:.4f})'.format(
            stats['sample_success'], stats['rate_sample_success']))
        log('Stage success: {} ({:.4f})'.format(
            stats['stage_success'], stats['rate_stage_success']))
        log('Total time: {}'.format(utils.time2str(stats['time'])))
        log('Average sample time: {}'.format(
            utils.time2str(stats['time_avg_sample'])))
        log('Average stage time: {}'.format(
            utils.time2str(stats['time_avg_stage'])))

        log_counts('Initial Visible Object:', self.sample_manager.n_init_vis)
        log_counts('Steps:', self.sample_manager.n_step)
        log_counts('Object:', self.sample_manager.n_obj)
        log_counts('Pair:', self.sample_manager.n_pair['gram_1'])
        log_counts('Move Type:', self.sample_manager.n_move_type)

        log('---')
        log('Balance State')
        balance = self.sample_manager.balance_state
        for group_name, group_state in sorted(balance.items()):
            log('# {}'.format(group_name))
            for sub_key, sub_value in sorted(group_state.items()):
                log('- {}: {}'.format(sub_key, sub_value))
コード例 #5
0
def shrinkids(input_path, output_path, toshrink):
    """Copy the HDF5 file *input_path* to *output_path*, remapping ids.

    For each entity in *toshrink*, an id map is built with ``index_table``
    (presumably old id -> new id; confirm against that helper) and applied
    by ``map_file`` to the entity's ``id`` column and to every listed field.

    :param input_path: HDF5 file to read.
    :param output_path: HDF5 file to write (opened with mode "w").
    :param toshrink: dict mapping an entity name to the fields holding ids of
        that entity.  A field written ``"other.field"`` lives in entity
        ``other`` instead of the keyed entity.
    """
    # The context managers close both files even when indexing or mapping
    # fails (same pattern as change_ids).
    with tables.open_file(input_path) as input_file, \
            tables.open_file(output_path, mode="w") as output_file:
        input_entities = input_file.root.entities
        print(" * indexing tables")
        idmaps = {}
        for ent_name in toshrink:
            print("    -", ent_name, "...", end=' ')
            start_time = time.time()
            idmaps[ent_name] = index_table(getattr(input_entities, ent_name))
            print("done (%s elapsed)." % time2str(time.time() - start_time))

        fields_to_change = {ent_name: {'id': idmaps[ent_name]}
                            for ent_name in toshrink}
        # items() instead of the Python-2-only iteritems()
        for ent_name, fields in toshrink.items():
            for field in fields:
                if '.' in field:
                    source_ent, field = field.split('.')
                else:
                    source_ent = ent_name
                # NOTE(review): raises KeyError when source_ent is not itself
                # a key of toshrink -- confirm callers never do that.
                fields_to_change[source_ent][field] = idmaps[ent_name]
        print(" * shrinking ids")
        map_file(input_file, output_file, fields_to_change)
コード例 #6
0
ファイル: idchanger.py プロジェクト: gvk489/liam2
def change_ids(input_path, output_path, changes, shuffle=False):
    """Build per-entity id maps from *input_path* and apply them with
    ``h5_apply_rec_map`` (reading *input_path*, writing *output_path*).

    :param input_path: HDF5 file whose ``entities`` tables are indexed.
    :param output_path: destination passed to ``h5_apply_rec_map``.
    :param changes: dict mapping an entity name to an iterable of
        ``(target_entity, field_name)`` pairs: that field of the entity holds
        ids of *target_entity* and is remapped with *target_entity*'s map.
        The entity's own ``id`` column is always remapped.
    :param shuffle: forwarded to ``get_shrink_dict``.
    :raises Exception: if an entity table contains the id -1, which is
        reserved for "no link" in link columns.
    """
    with tables.open_file(input_path) as input_file:
        input_entities = input_file.root.entities
        print(" * indexing entities tables")
        idmaps = {}
        # Iterate the dict itself: iterkeys() does not exist on Python 3.
        for ent_name in changes:
            print("    -", ent_name, "...", end=' ')
            start_time = time.time()
            table = getattr(input_entities, ent_name)

            # presumably maps old id -> new id; confirm in get_shrink_dict
            new_ids = get_shrink_dict(table.col('id'), shuffle=shuffle)
            if -1 in new_ids:
                raise Exception('found id == -1 in %s which is invalid (link '
                                'columns can be -1)' % ent_name)
            # -1 links should stay -1
            new_ids[-1] = -1
            idmaps[ent_name] = new_ids
            print("done (%s elapsed)." % time2str(time.time() - start_time))

    print(" * modifying ids")
    fields_maps = {ent_name: {'id': idmaps[ent_name]} for ent_name in changes}
    # items() instead of the Python-2-only iteritems()
    for ent_name, fields in changes.items():
        for target_ent, fname in fields:
            fields_maps[ent_name][fname] = idmaps[target_ent]
    h5_apply_rec_map(input_path, output_path, {'entities': fields_maps})
コード例 #7
0
ファイル: simulation.py プロジェクト: leeseungho90/liam2
        def simulate_period(period_idx, period, processes, entities,
                            init=False):
            """Simulate one period, printing progress to stdout.

            Unless *init* is true, each entity's data for *period* is loaded
            first (with *init*, only the current entity sizes are printed).
            Every entity is then tagged with *period*, each process whose
            periodicity divides *period_idx* is run, the entities' period
            data is stored, and the total number of rows over all entities is
            recorded in ``period_objects[period]``.

            Relies on names defined outside this function: ``self``,
            ``globals_data``, ``process_time``, ``period_objects``,
            ``config``, ``timed``, ``gettime`` and ``time2str``.

            :param period_idx: index of *period*; a process with periodicity
                N runs only when ``period_idx % N == 0``.
            :param period: the period to simulate.
            :param processes: sequence of ``(process, periodicity)`` pairs.
            :param entities: entities to load, update and store.
            :param init: True for the initial period (no data is loaded).
            """
            print("\nperiod", period)
            if init:
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                print("- loading input data")
                for entity in entities:
                    # the "..." line is completed by whatever timed() prints
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.load_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            # tag every entity (and its 'period' column) with this period
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'__simulation__': self,
                              'period': period,
                              'nan': float('nan'),
                              '__globals__': globals_data}

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    print("- %d/%d" % (p_num, num_processes), process.name,
                          end=' ')
                    print("...", end=' ')
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, self,
                                             const_dict)
                    else:
                        elapsed = 0
                        # note: "done." is still printed below for a skipped
                        # process
                        print("skipped (periodicity)")

                    # accumulate per-process time over all periods
                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print("done (%s elapsed)." % time2str(elapsed))
                    else:
                        print("done.")
                    # called after every process, whether it ran or not
                    self.start_console(process.entity, period,
                                       globals_data)

            print("- storing period data")
            for entity in entities:
                print("  *", entity.name, "...", end=' ')
                timed(entity.store_period_data, period)
                print("    -> %d individuals" % len(entity.array))
            # Disabled code (kept for reference): per-level compression of
            # the stored period data.
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            # total number of rows over all entities for this period
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
コード例 #8
0
ファイル: render.py プロジェクト: hughplay/TVR
 def render_seg(self):
     """Render a shadeless segmentation image for every camera in turn.

     Logs the time taken per camera.  Returns False as soon as a
     segmentation file is missing right after its render call, True once
     every camera has been rendered.
     """
     self.info('[Render] seg ({})...'.format(self.current_stage))
     for cam_key, out_path in self.segs_path.items():
         camera = self.scene.cameras[cam_key]
         elapsed = self.scene.render_shadeless(
             out_path, camera, config.seg_width, config.seg_height)
         self.info('- {}: {}'.format(cam_key, utils.time2str(elapsed)))
         rendered = os.path.isfile(out_path)
         if not rendered:
             return False
     return True
コード例 #9
0
ファイル: simulation.py プロジェクト: abozio/Myliam2
        def simulate_period(period_idx, period, processes, entities, init=False):
            """Simulate one period, printing progress to stdout (Python 2).

            Unless *init* is true, each entity's data for *period* is loaded
            first (with *init*, only the current entity sizes are printed).
            Every entity is then tagged with *period*, each process whose
            periodicity divides *period_idx* is run, the entities' period
            data is stored, and the total number of rows over all entities is
            recorded in ``period_objects[period]``.

            Relies on names defined outside this function: ``self``,
            ``globals_data``, ``process_time``, ``period_objects``,
            ``config``, ``timed``, ``gettime`` and ``time2str``.

            :param period_idx: index of *period*; a process with periodicity
                N runs only when ``period_idx % N == 0``.
            :param period: the period to simulate.
            :param processes: sequence of ``(process, periodicity)`` pairs.
            :param entities: entities to load, update and store.
            :param init: True for the initial period (no data is loaded).
            """
            print "\nperiod", period
            if init:
                for entity in entities:
                    print "  * %s: %d individuals" % (entity.name, len(entity.array))
            else:
                print "- loading input data"
                for entity in entities:
                    # trailing comma: the line is continued by what follows
                    print "  *", entity.name, "...",
                    timed(entity.load_period_data, period)
                    print "    -> %d individuals" % len(entity.array)
            # tag every entity (and its "period" column) with this period
            for entity in entities:
                entity.array_period = period
                entity.array["period"] = period

            if processes:
                # build context for this period:
                const_dict = {"period": period, "nan": float("nan"), "__globals__": globals_data}

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    print "- %d/%d" % (p_num, num_processes), process.name,
                    # TODO: provide a custom __str__ method for Process &
                    # Assignment instead
                    # show the predictor only when it differs from the name
                    if hasattr(process, "predictor") and process.predictor and process.predictor != process.name:
                        print "(%s)" % process.predictor,
                    print "...",
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, self, const_dict)
                    else:
                        elapsed = 0
                        # note: "done." is still printed below for a skipped
                        # process
                        print "skipped (periodicity)"

                    # accumulate per-process time over all periods
                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print "done (%s elapsed)." % time2str(elapsed)
                    else:
                        print "done."
                    # called after every process, whether it ran or not
                    self.start_console(process.entity, period, globals_data)

            print "- storing period data"
            for entity in entities:
                print "  *", entity.name, "...",
                timed(entity.store_period_data, period)
                print "    -> %d individuals" % len(entity.array)
            # Disabled code (kept for reference): per-level compression of
            # the stored period data.
            #            print " - compressing period data"
            #            for entity in entities:
            #                print "  *", entity.name, "...",
            #                for level in range(1, 10, 2):
            #                    print "   %d:" % level,
            #                    timed(entity.compress_period_data, level)
            # total number of rows over all entities for this period
            period_objects[period] = sum(len(entity.array) for entity in entities)
コード例 #10
0
ファイル: simulation.py プロジェクト: fagan2888/liam2
def show_top_times(what, times, count):
    """Print the first *count* (name, duration) entries of *times*.

    Each line shows the entry's share of the duration summed over *all* of
    *times* (not only the printed ones); if that sum is zero, every entry is
    shown as 100%.  The last line gives the sum of the printed durations.

    >>> show_top_times("letters", [('a', 0.1), ('b', 0.2)], 5)
    top 5 letters:
     - a: 0.10 second (33%)
     - b: 0.20 second (66%)
    total for top 5 letters: 0.30 second
    >>> show_top_times("zeros", [('a', 0)], 5)
    top 5 zeros:
     - a: 0 ms (100%)
    total for top 5 zeros: 0 ms
    """
    grand_total = sum(duration for _, duration in times)
    print("top %d %s:" % (count, what))
    top_entries = times[:count]
    for label, duration in top_entries:
        try:
            share = 100.0 * duration / grand_total
        except ZeroDivisionError:
            # the durations sum to zero: report each entry as the whole
            share = 100
        print(" - %s: %s (%d%%)" % (label, time2str(duration), share))
    print("total for top %d %s:" % (count, what), end=' ')
    print(time2str(sum(duration for _, duration in top_entries)))
コード例 #11
0
ファイル: simulation.py プロジェクト: TaxIPP-Life/liam2
def show_top_times(what, times, count):
    """Report the *count* leading (name, duration) pairs of *times*.

    Percentages are relative to the sum of every duration in *times*; a zero
    sum makes each reported entry 100%.  The closing line is the sum of the
    reported durations only.

    >>> show_top_times("letters", [('a', 0.1), ('b', 0.2)], 5)
    top 5 letters:
     - a: 0.10 second (33%)
     - b: 0.20 second (66%)
    total for top 5 letters: 0.30 second
    >>> show_top_times("zeros", [('a', 0)], 5)
    top 5 zeros:
     - a: 0 ms (100%)
    total for top 5 zeros: 0 ms
    """
    overall = sum(t for _, t in times)
    heading = "%d %s:" % (count, what)
    print("top " + heading)
    shown = times[:count]
    for item_name, item_time in shown:
        try:
            pct = 100.0 * item_time / overall
        except ZeroDivisionError:
            pct = 100
        print(" - %s: %s (%d%%)" % (item_name, time2str(item_time), pct))
    print("total for top " + heading, end=' ')
    shown_total = sum(t for _, t in shown)
    print(time2str(shown_total))
コード例 #12
0
ファイル: render.py プロジェクト: hughplay/TVR
 def render_main(self):
     """Render the main image for each camera whose image file is missing.

     Returns True immediately when ``self.config.no_render`` is set.
     Otherwise logs the time taken per rendered camera and returns False as
     soon as an image is still missing after its render call, True at the
     end.
     """
     if self.config.no_render:
         return True
     self.info('[Render] main ({})...'.format(self.current_stage))
     for cam_key, out_path in self.images_path.items():
         if os.path.isfile(out_path):
             # image already exists: nothing to render for this camera
             continue
         elapsed = self.scene.render(out_path, self.scene.cameras[cam_key],
                                     config.width, config.height)
         self.info('- {}: {}'.format(cam_key, utils.time2str(elapsed)))
         if not os.path.isfile(out_path):
             return False
     return True
コード例 #13
0
ファイル: data.py プロジェクト: jonathangoupille/Myliam2
def index_tables(globals_fields, entities, fpath):
    print "reading data from %s ..." % fpath

    input_file = tables.openFile(fpath, mode="r")
    try:
        periodic_globals = None
        input_root = input_file.root

        if globals_fields:
            if ('globals' not in input_root or 
                'periodic' not in input_root.globals):
                raise Exception('could not find globals in the input data '
                                'file (but they are declared in the '
                                'simulation file)')
            globals_table = input_root.globals.periodic
            # load globals in memory
            #FIXME: make sure either period or PERIOD is present
            assertValidFields(globals_table, globals_fields,
                              allowed_missing=('period', 'PERIOD'))
            periodic_globals = globals_table.read()

        input_entities = input_root.entities

        entities_tables = {}
        dataset = {'globals': periodic_globals,
                   'entities': entities_tables}

        print " * indexing tables"
        for ent_name, entity in entities.iteritems():
            print "    -", ent_name, "...",

            table = getattr(input_entities, ent_name)
            assertValidFields(table, entity.fields, entity.missing_fields)

            start_time = time.time()
            rows_per_period, id_to_rownum_per_period = index_table(table)
            indexed_table = IndexedTable(table, rows_per_period,
                                         id_to_rownum_per_period)
            entities_tables[ent_name] = indexed_table
            print "done (%s elapsed)." % time2str(time.time() - start_time)
    except:
        input_file.close()
        raise

    return input_file, dataset
コード例 #14
0
 def play_bar(self, screen, total, pos):
     """Draw the progress bar of the currently played audio file on *screen*.

     :param screen: pygame surface to draw on; it is also returned.
     :param total: total duration as a string; parsed with ``str2time`` and
         displayed verbatim.
     :param pos: current position, presumably in the unit returned by
         ``str2time`` (TODO confirm); -1 is treated as 0.
     :return: *screen*.
     """
     bar_width = 300
     if pos == -1:
         pos = 0
     total_time = str2time(total)
     # Pixel width of the elapsed part.  The fraction is computed in floating
     # point: the former ``str2time(total) / 300`` factor was 0 under Python 2
     # integer division for tracks shorter than 300 units and then crashed
     # with ZeroDivisionError.  A zero-length track draws an empty bar, and
     # the width is clamped so the marker never leaves the frame.
     if total_time > 0:
         progress = int(bar_width * float(pos) / total_time)
     else:
         progress = 0
     progress = max(0, min(progress, bar_width))
     # write time string to display
     screen.blit(self.font.render('Time:', True, WHITE), (10, 135))
     screen.blit(self.font.render(time2str(pos) + '/' + total, True, WHITE),
                 (70, 135))
     # current position as a small vertical rectangle
     pygame.draw.rect(screen, WHITE,
                      pygame.Rect(progress + 10, 160, 5, 25), 0)
     # filled (elapsed) part of the bar
     pygame.draw.rect(screen, WHITE, pygame.Rect(10, 165, progress, 15), 0)
     # frame of the bar (outline width 1)
     pygame.draw.rect(screen, WHITE, pygame.Rect(10, 165, bar_width, 15), 1)
     return screen
コード例 #15
0
ファイル: graphics.py プロジェクト: psikon/pitft-scripts
 def play_bar(self, screen, total, pos):
     """Draw the progress bar of the currently played audio file on *screen*.

     :param screen: pygame surface to draw on; it is also returned.
     :param total: total duration as a string; parsed with ``str2time`` and
         displayed verbatim.
     :param pos: current position, presumably in the unit returned by
         ``str2time`` (TODO confirm); -1 is treated as 0.
     :return: *screen*.
     """
     bar_width = 300
     if pos == -1:
         pos = 0
     total_time = str2time(total)
     # Pixel width of the elapsed part.  The fraction is computed in floating
     # point: the former ``str2time(total)/300`` factor was 0 under Python 2
     # integer division for tracks shorter than 300 units and then crashed
     # with ZeroDivisionError.  A zero-length track draws an empty bar, and
     # the width is clamped so the marker never leaves the frame.
     if total_time > 0:
         progress = int(bar_width * float(pos) / total_time)
     else:
         progress = 0
     progress = max(0, min(progress, bar_width))
     # write time string to display
     screen.blit(self.font.render('Time:', True, WHITE), (10, 135))
     screen.blit(self.font.render(time2str(pos) + '/' + total, True,
         WHITE), (70, 135))
     # current position as a small vertical rectangle
     pygame.draw.rect(screen, WHITE,
         pygame.Rect(progress + 10, 160, 5, 25), 0)
     # filled (elapsed) part of the bar
     pygame.draw.rect(screen, WHITE,
         pygame.Rect(10, 165, progress, 15), 0)
     # frame of the bar (outline width 1)
     pygame.draw.rect(screen, WHITE,
         pygame.Rect(10, 165, bar_width, 15), 1)
     return screen
コード例 #16
0
def main():
    """Parse the command line, fill in ``conf`` and start training.

    Creates a logging directory ``logs/<conf.time_id>`` next to this file
    (``time_id`` comes from ``time2str()``), with image, model and writer
    sub-directories and a tensorboardX ``SummaryWriter``.  It then copies the
    hyper-parameters from the command line into ``conf``, echoing each one,
    and calls ``train(conf)``.
    """
    args = parser.parse_args()

    # initialize global path
    global_path = os.path.dirname(os.path.realpath(__file__))
    conf.global_path = global_path
    # the '{}' placeholder is required for the path to actually be printed
    print('the global path: {}'.format(global_path))

    # configure the logging path.
    conf.time_id = time2str()
    conf.logging_path = join(global_path, './logs', conf.time_id)
    conf.writting_path = join(conf.logging_path, './logging')

    # configure checkpoint for images and models.
    conf.image_directory = join(conf.logging_path, conf.IMAGE_SAVING_DIRECTORY)
    conf.model_directory = join(conf.logging_path,
                                conf.MODEL_SAVEING_DIRECTORY)
    build_dirs(conf.image_directory)
    build_dirs(conf.model_directory)
    build_dirs(conf.writting_path)
    conf.writer = tensorboardX.SummaryWriter(conf.writting_path)

    # Setting parameters
    conf.max_epochs = args.epochs
    print('number epochs: {}'.format(conf.max_epochs))
    conf.num_data_workers = args.workers
    print('number of workers: {}'.format(conf.num_data_workers))
    conf.lr = args.learning_rate
    print('learning rate: {}'.format(conf.lr))
    conf.batch_size = args.batch_size
    print('batch size: {}'.format(conf.batch_size))
    conf.max_iterations = args.max_iterations
    print('max number of iterations: {}'.format(conf.max_iterations))
    conf.n_critic = args.number_critic
    print('number of critic training: {}'.format(conf.n_critic))
    conf.gp_lambda = args.gp_lambda
    print('gradient penalty weight: {}'.format(conf.gp_lambda))

    train(conf)
コード例 #17
0
ファイル: data.py プロジェクト: AlexisEidelman/liam2_mahdi
def index_tables(globals_def, entities, fpath):
    """Open the HDF5 input file *fpath*, load its globals and index entities.

    :param globals_def: dict mapping global name -> definition dict.  Globals
        whose definition has a ``'path'`` key are expected to be loaded by
        ``load_path_globals``; the others are read from ``/globals/<name>``
        and checked against ``'type'`` (or ``'fields'``).
    :param entities: dict mapping entity name -> entity object whose
        ``fields`` / ``missing_fields`` are checked against its table.
    :param fpath: path of the HDF5 input file.
    :return: ``(input_file, {'globals': ..., 'entities': ...})`` where
        ``'entities'`` maps each entity name to an ``IndexedTable``.
        *input_file* stays open for the caller; it is closed here only if an
        error occurs.
    :raises Exception: if a global that must come from the file is missing.
    """
    print("reading data from %s ..." % fpath)

    input_file = tables.open_file(fpath)
    try:
        input_root = input_file.root

        # values()/items() instead of the Python-2-only itervalues()/
        # iteritems(): same iteration, also valid on Python 3.
        if any('path' not in g_def for g_def in globals_def.values()) and \
                'globals' not in input_root:
            raise Exception('could not find any globals in the input data file '
                            '(but some are declared in the simulation file)')

        globals_data = load_path_globals(globals_def)

        globals_node = getattr(input_root, 'globals', None)
        for name, global_def in globals_def.items():
            # already loaded from another source (path)
            if name in globals_data:
                continue

            # globals_node is None when the file has no /globals group: report
            # the missing global instead of failing with a TypeError on "in".
            if globals_node is None or name not in globals_node:
                raise Exception("could not find 'globals/%s' in the input "
                                "data file" % name)

            global_data = getattr(globals_node, name)

            global_type = global_def.get('type', global_def.get('fields'))
            # TODO: move the checking (assertValidType) to a separate function
            assert_valid_type(global_data, global_type, context=name)
            array = global_data.read()
            if isinstance(global_type, list):
                # make sure we do not keep in memory columns which are
                # present in the input file but were not asked for by the
                # modeller. They are not accessible anyway.
                array = add_and_drop_fields(array, global_type)
            attrs = global_data.attrs
            dim_names = getattr(attrs, 'dimensions', None)
            if dim_names is not None:
                # we serialise dim_names as a numpy array so that it is
                # stored as a native hdf type and not a pickle but we
                # prefer to work with simple lists
                dim_names = list(dim_names)
                pvalues = [getattr(attrs, 'dim%d_pvalues' % i)
                           for i in range(len(dim_names))]
                array = LabeledArray(array, dim_names, pvalues)
            globals_data[name] = array

        input_entities = input_root.entities

        entities_tables = {}
        print(" * indexing tables")
        for ent_name, entity in entities.items():
            print("    -", ent_name, "...", end=' ')

            table = getattr(input_entities, ent_name)
            assert_valid_type(table, entity.fields, entity.missing_fields)

            start_time = time.time()
            rows_per_period, id_to_rownum_per_period = index_table(table)
            indexed_table = IndexedTable(table, rows_per_period,
                                         id_to_rownum_per_period)
            entities_tables[ent_name] = indexed_table
            print("done (%s elapsed)." % time2str(time.time() - start_time))
    except:
        # close the file on any failure, then propagate the original error
        input_file.close()
        raise

    return input_file, {'globals': globals_data, 'entities': entities_tables}
コード例 #18
0
ファイル: simulation.py プロジェクト: TaxIPP-Life/liam2
        def simulate_period(period_idx, period, periods, processes, entities,
                            init=False):
            """Simulate one period: load data, run processes, store results.

            Relies on names defined outside this function: ``self``,
            ``eval_ctx``, ``config``, ``globals_data``, ``process_time``,
            ``period_objects``, ``main_start_time``, ``time_period``,
            ``timed``, ``gettime``, ``time2str``, ``HDFStore`` and
            ``DataFrame``.

            :param period_idx: index of *period*; when there are processes,
                ``periods[period_idx + 1] == period`` is asserted.
            :param period: the period to simulate.
            :param periods: sequence of periods (see *period_idx*).
            :param processes: sequence of ``(process, periodicity, start)``
                triples.  An int periodicity runs the process when
                ``period_idx % periodicity == 0``; any other periodicity must
                be a key of ``time_period`` and is matched against the month
                part (``period % 100``) of *period*.
            :param entities: entities to load, update and store; one of them
                must be named ``'person'`` (its ``sali`` and ``workstate``
                columns feed ``self.longitudinal``).
            :param init: True for the initial period.  With the
                "procedures"/"processes" log levels, entity sizes are then
                printed instead of loading period data, and the longitudinal
                tables are (re)initialised from the input file.
            :raises Exception: if *period* is already a column of an existing
                longitudinal table (outside the initial period).
            """
            period_start_time = time.time()

            # set current period
            eval_ctx.period = period

            if config.log_level in ("procedures", "processes"):
                print()
            print("period", period,
                  end=" " if config.log_level == "periods" else "\n")
            if init and config.log_level in ("procedures", "processes"):
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                # NOTE(review): with init=True and a log level other than
                # "procedures"/"processes", period data IS loaded here --
                # confirm this asymmetry is intended.
                if config.log_level in ("procedures", "processes"):
                    print("- loading input data")
                    for entity in entities:
                        print("  *", entity.name, "...", end=' ')
                        timed(entity.load_period_data, period)
                        print("    -> %d individuals" % len(entity.array))
                else:
                    for entity in entities:
                        entity.load_period_data(period)
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'period_idx': period_idx + 1,
                              'periods': periods,
                              'periodicity': time_period[self.time_scale] * (1 - 2 * (self.retro)),
                              'longitudinal': self.longitudinal,
                              'format_date': self.time_scale,
                              'pension': None,
                              '__simulation__': self,
                              'period': period,
                              'nan': float('nan'),
                              '__globals__': globals_data}
                assert periods[period_idx + 1] == period

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):

                    process, periodicity, start = process_def
                    if config.log_level in ("procedures", "processes"):
                        print("- %d/%d" % (p_num, num_processes), process.name,
                              end=' ')
                        print("...", end=' ')
                    # TODO: change that
                    if isinstance(periodicity, int):
                        if period_idx % periodicity == 0:
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)
                        else:
                            elapsed = 0
                            # Only report skips when per-process lines are
                            # printed, as in the other periodicity branch.
                            if config.log_level in ("procedures", "processes"):
                                print("skipped (periodicity)")
                    else:
                        assert periodicity in time_period
                        periodicity_process = time_period[periodicity]
                        periodicity_simul = time_period[self.time_scale]
                        month_idx = period % 100
                        # first condition, to run a process with start == 12
                        # each year even if year are yyyy01
                        # modify start if periodicity_simul is not month
                        start = int(start / periodicity_simul - 0.01) * periodicity_simul + 1

                        if (periodicity_process <= periodicity_simul and self.time_scale != 'year0') or (
                                month_idx % periodicity_process == start % periodicity_process):

                            # NOTE(review): this override stays in const_dict
                            # for the following processes of this period,
                            # including int-periodicity ones -- confirm.
                            const_dict['periodicity'] = periodicity_process * (1 - 2 * (self.retro))
                            elapsed, _ = gettime(process.run_guarded, self, const_dict)
                        else:
                            elapsed = 0
                            # Reported only here, where the process really
                            # did not run.
                            if config.log_level in ("procedures", "processes"):
                                print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.log_level in ("procedures", "processes"):
                        if config.show_timings:
                            print("done (%s elapsed)." % time2str(elapsed))
                        else:
                            print("done.")
                    self.start_console(eval_ctx)

            # update longitudinal
            # TODO: check whether a get_entity helper could replace this
            # linear search (IndexError if there is no 'person' entity).
            person = [x for x in entities if x.name == 'person'][0]
            person_ids = person.array.columns['id']

            for varname in ['sali', 'workstate']:
                var = person.array.columns[varname]
                if init:
                    fpath = self.data_source.input_path
                    input_file = HDFStore(fpath, mode="r")
                    try:
                        if 'longitudinal' in input_file.root:
                            input_longitudinal = input_file.root.longitudinal
                            if varname in input_longitudinal:
                                self.longitudinal[varname] = input_file['/longitudinal/' + varname]
                                if period not in self.longitudinal[varname].columns:
                                    table = DataFrame({'id': person_ids, period: var})
                                    self.longitudinal[varname] = self.longitudinal[varname].merge(
                                        table, on='id', how='outer')
                            else:
                                # when one variable is not in the input_file
                                self.longitudinal[varname] = DataFrame({'id': person_ids, period: var})
                        else:
                            # when there is no longitudinal in the dataset
                            self.longitudinal[varname] = DataFrame({'id': person_ids, period: var})
                    finally:
                        # Always release the file handle; the DataFrame read
                        # from the store does not need it to stay open.
                        input_file.close()
                else:
                    table = DataFrame({'id': person_ids, period: var})
                    if period in self.longitudinal[varname]:
                        # Merging again would duplicate the period column;
                        # fail explicitly instead of stopping in a debugger.
                        raise Exception("period %s is already present in the "
                                        "longitudinal data for '%s'"
                                        % (period, varname))
                    self.longitudinal[varname] = self.longitudinal[varname].merge(table, on='id', how='outer')

            if config.log_level in ("procedures", "processes"):
                print("- storing period data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.store_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            else:
                for entity in entities:
                    entity.store_period_data(period)

            # total number of rows over all entities for this period
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            period_elapsed_time = time.time() - period_start_time
            if config.log_level in ("procedures", "processes"):
                print("period %d" % period, end=' ')
            print("done", end=' ')
            if config.show_timings:
                print("(%s elapsed)" % time2str(period_elapsed_time), end="")
                if init:
                    print(".")
                else:
                    main_elapsed_time = time.time() - main_start_time
                    periods_done = period_idx + 1
                    remaining_periods = self.periods - periods_done
                    avg_time = main_elapsed_time / periods_done
                    # future_time = period_elapsed_time * 0.4 + avg_time * 0.6
                    remaining_time = avg_time * remaining_periods
                    print(" - estimated remaining time: %s."
                          % time2str(remaining_time))
            else:
                print()
コード例 #19
0
ファイル: data.py プロジェクト: LouisePaulDelvaux/liam2
def index_tables(globals_def, entities, fpath):
    """Open the HDF5 file *fpath*, load its globals and index its entities.

    :param globals_def: mapping of global name -> declared type. Every name
        must exist under the file's ``globals`` group, otherwise an
        ``Exception`` is raised.
    :param entities: mapping of entity name -> entity definition; each entity
        needs a same-named table under the file's ``entities`` group.
    :param fpath: path of the HDF5 input file.
    :returns: ``(input_file, {'globals': ..., 'entities': ...})`` where
        'globals' maps names to in-memory arrays (``LabeledArray`` when
        dimension metadata is present) and 'entities' maps names to
        ``IndexedTable`` objects wrapping table nodes of *input_file*. The
        file is left open on success and closed before any error propagates.
    """
    print("reading data from %s ..." % fpath)

    def read_global(globals_node, name, global_type):
        # Check, validate and load one global node into memory.
        if name not in globals_node:
            raise Exception('could not find %s in the input data file '
                            'globals' % name)

        node = getattr(globals_node, name)
        allowed_missing = None
        if name == 'periodic':
            # 'periodic' must carry a period column (either case)
            field_names = set(node.dtype.names)
            if field_names.isdisjoint(('period', 'PERIOD')):
                raise Exception("Table 'periodic' in hdf5 input file "
                                "is missing a field named 'PERIOD'")
            allowed_missing = ('period', 'PERIOD')
        assert_valid_type(node, global_type, allowed_missing, name)

        values = node.read()
        node_attrs = node.attrs
        dims = getattr(node_attrs, 'dimensions', None)
        if dims is None:
            return values
        # dimension names are serialised as a numpy array (so they are a
        # native hdf type rather than a pickle); work with a plain list
        dims = list(dims)
        pvalues = [getattr(node_attrs, 'dim%d_pvalues' % idx)
                   for idx, _ in enumerate(dims)]
        return LabeledArray(values, dims, pvalues)

    input_file = tables.openFile(fpath, mode="r")
    try:
        root = input_file.root

        # TODO: move the checking (assertValidType) to a separate function
        loaded_globals = {}
        if globals_def:
            if 'globals' not in root:
                raise Exception('could not find any globals in the input data '
                                'file (but they are declared in the '
                                'simulation file)')
            globals_node = root.globals
            for name, global_type in globals_def.iteritems():
                loaded_globals[name] = read_global(globals_node, name,
                                                   global_type)

        entities_node = root.entities

        indexed_tables = {}
        print(" * indexing tables")
        for ent_name, entity in entities.iteritems():
            print("    -", ent_name, "...", end=' ')

            table = getattr(entities_node, ent_name)
            assert_valid_type(table, entity.fields, entity.missing_fields)

            t0 = time.time()
            rows_per_period, id_to_rownum_per_period = index_table(table)
            indexed_tables[ent_name] = IndexedTable(table, rows_per_period,
                                                    id_to_rownum_per_period)
            print("done (%s elapsed)." % time2str(time.time() - t0))
    except BaseException:
        # close the file on any failure (including KeyboardInterrupt), then
        # let the error propagate unchanged
        input_file.close()
        raise

    return input_file, {'globals': loaded_globals, 'entities': indexed_tables}
コード例 #20
0
ファイル: train.py プロジェクト: NougatCA/CodePtr
    def train_iter(self):
        """Train for ``config.n_epochs`` epochs, then save the best model.

        Each epoch iterates over ``self.dataloader``, feeding every batch to
        ``self.train_one_batch`` and accumulating the returned loss in a
        running meter. Progress is logged every ``config.log_state_every``
        batches, and the model is validated after each epoch. When
        ``config.use_early_stopping`` is set, ``self.early_stopping`` may end
        training early. When ``config.use_lr_decay`` is set, the LR scheduler
        is stepped once per epoch.

        After the loop, the state dict held by ``self.early_stopping`` is saved
        as ``train.best.pt``, and the mean epoch duration is logged.
        """
        loss = utils.AverageMeter()
        epoch_times = []

        for epoch in range(1, config.n_epochs + 1):
            epoch_time = utils.Timer()
            # The initial label must show the *current* epoch (it used to be
            # hard-coded to 1) and use the same format as set_description.
            epoch_bar = tqdm(
                self.dataloader,
                desc='Epoch: {}/{} [loss: {:.4f}, perplexity: {:.4f}]'.format(
                    epoch, config.n_epochs, 0.0000, 0.0000))

            for index_batch, batch in enumerate(epoch_bar):

                batch_size = batch.batch_size

                batch_loss = self.train_one_batch(batch, batch_size)
                loss.update(batch_loss, batch_size)

                # perplexity is reported as exp(average loss)
                epoch_bar.set_description(
                    'Epoch: {}/{} [loss: {:.4f}, perplexity: {:.4f}]'.format(
                        epoch, config.n_epochs, loss.avg, math.exp(loss.avg)))

                if index_batch % config.log_state_every == 0:
                    logger.debug(
                        'Epoch: {}/{}, time: {:.2f}, loss: {:.4f}, perplexity: {:.4f}'
                        .format(epoch, config.n_epochs, epoch_time.time(),
                                loss.avg, math.exp(loss.avg)))

            epoch_time.stop()
            epoch_times.append(epoch_time.time())

            logger.info(
                'Epoch {} finished, time: {:.2f}, loss: {:.4f}, perplexity: {:.4f}'
                .format(epoch, epoch_time.time(), loss.avg,
                        math.exp(loss.avg)))

            # start the next epoch with a fresh running average
            loss.reset()

            self.validate(epoch)

            if config.use_early_stopping:
                if self.early_stopping.early_stop:
                    break

            if config.use_lr_decay:
                self.lr_scheduler.step()

            logger.debug('learning rate: {:.6f}'.format(
                self.optimizer.param_groups[0]['lr']))

        logger.info(
            'Training finished, best model at the end of epoch {}'.format(
                self.early_stopping.best_epoch))

        # save best model, i.e. model with min valid loss
        path = self.save_model(name='train.best.pt',
                               state_dict=self.early_stopping.best_model)
        logger.info('Best model is saved as {}'.format(path))

        # time statistics
        avg_epoch_time = np.mean(epoch_times)
        logger.info('Average time consumed by each epoch: {}'.format(
            utils.time2str(avg_epoch_time)))
コード例 #21
0
ファイル: simulation.py プロジェクト: TaxIPP-Life/liam2
    def run(self, run_console=False):
        start_time = time.time()

        h5in, h5out, globals_data = timed(self.data_source.run,
                                          self.globals_def,
                                          entity_registry,
                                          self.init_period)

        if config.autodump or config.autodiff:
            if config.autodump:
                fname, _ = config.autodump
                mode = 'w'
            else:  # config.autodiff
                fname, _ = config.autodiff
                mode = 'r'
            fpath = os.path.join(config.output_directory, fname)
            h5_autodump = tables.open_file(fpath, mode=mode)
            config.autodump_file = h5_autodump
        else:
            h5_autodump = None

#        input_dataset = self.data_source.run(self.globals_def,
#                                             entity_registry)
#        output_dataset = self.data_sink.prepare(self.globals_def,
#                                                entity_registry)
#        output_dataset.copy(input_dataset, self.init_period - 1)
#        for entity in input_dataset:
#            indexed_array = buildArrayForPeriod(entity)

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide='ignore', invalid='ignore')

        process_time = defaultdict(float)
        period_objects = {}
        eval_ctx = EvaluationContext(self, self.entities_map, globals_data)

        def simulate_period(period_idx, period, periods, processes, entities,
                            init=False):
            period_start_time = time.time()

            # set current period
            eval_ctx.period = period

            if config.log_level in ("procedures", "processes"):
                print()
            print("period", period,
                  end=" " if config.log_level == "periods" else "\n")
            if init and config.log_level in ("procedures", "processes"):
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                if config.log_level in ("procedures", "processes"):
                    print("- loading input data")
                    for entity in entities:
                        print("  *", entity.name, "...", end=' ')
                        timed(entity.load_period_data, period)
                        print("    -> %d individuals" % len(entity.array))
                else:
                    for entity in entities:
                        entity.load_period_data(period)
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'period_idx': period_idx + 1,
                              'periods': periods,
                              'periodicity': time_period[self.time_scale] * (1 - 2 * (self.retro)),
                              'longitudinal': self.longitudinal,
                              'format_date': self.time_scale,
                              'pension': None,
                              '__simulation__': self,
                              'period': period,
                              'nan': float('nan'),
                              '__globals__': globals_data}
                assert(periods[period_idx + 1] == period)

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):

                    process, periodicity, start = process_def
                    if config.log_level in ("procedures", "processes"):
                        print("- %d/%d" % (p_num, num_processes), process.name,
                              end=' ')
                        print("...", end=' ')
                    # TDOD: change that
                    if isinstance(periodicity, int):
                        if period_idx % periodicity == 0:
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)
                        else:
                            elapsed = 0
                            print("skipped (periodicity)")
                    else:
                        assert periodicity in time_period
                        periodicity_process = time_period[periodicity]
                        periodicity_simul = time_period[self.time_scale]
                        month_idx = period % 100
                        # first condition, to run a process with start == 12
                        # each year even if year are yyyy01
                        # modify start if periodicity_simul is not month
                        start = int(start / periodicity_simul - 0.01) * periodicity_simul + 1

                        if (periodicity_process <= periodicity_simul and self.time_scale != 'year0') or (
                                month_idx % periodicity_process == start % periodicity_process):

                            const_dict['periodicity'] = periodicity_process * (1 - 2 * (self.retro))
                            elapsed, _ = gettime(process.run_guarded, self, const_dict)
                        else:
                            elapsed = 0

                        if config.log_level in ("procedures", "processes"):
                            print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.log_level in ("procedures", "processes"):
                        if config.show_timings:
                            print("done (%s elapsed)." % time2str(elapsed))
                        else:
                            print("done.")
                    self.start_console(eval_ctx)

            # update longitudinal
            person = [x for x in entities if x.name == 'person'][0]
            # maybe we have a get_entity or anything more nice than that #TODO: check
            id = person.array.columns['id']

            for varname in ['sali', 'workstate']:
                var = person.array.columns[varname]
                if init:
                    fpath = self.data_source.input_path
                    input_file = HDFStore(fpath, mode="r")
                    if 'longitudinal' in input_file.root:
                        input_longitudinal = input_file.root.longitudinal
                        if varname in input_longitudinal:
                            self.longitudinal[varname] = input_file['/longitudinal/' + varname]
                            if period not in self.longitudinal[varname].columns:
                                table = DataFrame({'id': id, period: var})
                                self.longitudinal[varname] = self.longitudinal[varname].merge(
                                    table, on='id', how='outer')
                        else:
                            # when one variable is not in the input_file
                            self.longitudinal[varname] = DataFrame({'id': id, period: var})
                    else:
                        # when there is no longitudinal in the dataset
                        self.longitudinal[varname] = DataFrame({'id': id, period: var})
                else:
                    table = DataFrame({'id': id, period: var})
                    if period in self.longitudinal[varname]:
                        import pdb
                        pdb.set_trace()
                    self.longitudinal[varname] = self.longitudinal[varname].merge(table, on='id', how='outer')

            if config.log_level in ("procedures", "processes"):
                print("- storing period data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.store_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            else:
                for entity in entities:
                    entity.store_period_data(period)

#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            period_elapsed_time = time.time() - period_start_time
            if config.log_level in ("procedures", "processes"):
                print("period %d" % period, end=' ')
            print("done", end=' ')
            if config.show_timings:
                print("(%s elapsed)" % time2str(period_elapsed_time), end="")
                if init:
                    print(".")
                else:
                    main_elapsed_time = time.time() - main_start_time
                    periods_done = period_idx + 1
                    remaining_periods = self.periods - periods_done
                    avg_time = main_elapsed_time / periods_done
                    # future_time = period_elapsed_time * 0.4 + avg_time * 0.6
                    remaining_time = avg_time * remaining_periods
                    print(" - estimated remaining time: %s."
                          % time2str(remaining_time))
            else:
                print()

        print("""
=====================
 starting simulation
=====================""")
        try:
            assert(self.time_scale in time_period)
            month_periodicity = time_period[self.time_scale]
            time_direction = 1 - 2 * (self.retro)
            time_step = month_periodicity * time_direction

            periods = [
                self.init_period + int(t / 12) * 100 + t % 12
                for t in range(0, (self.periods + 1) * time_step, time_step)
                ]
            if self.time_scale == 'year0':
                periods = [self.init_period + t for t in range(0, (self.periods + 1))]
            print("simulated period are going to be: ", periods)

            init_start_time = time.time()
            simulate_period(0, self.init_period, [None, periods[0]], self.init_processes, self.entities, init=True)

            time_init = time.time() - init_start_time
            main_start_time = time.time()

            for period_idx, period in enumerate(periods[1:]):
                period_start_time = time.time()
                simulate_period(period_idx, period, periods,
                                self.processes, self.entities)

#                 if self.legislation:
#                     if not self.legislation['ex_post']:
#
#                         elapsed, _ = gettime(liam2of.main,period)
#                         process_time['liam2of'] += elapsed
#                         elapsed, _ = gettime(of_on_liam.main,self.legislation['annee'],[period])
#                         process_time['legislation'] += elapsed
#                         elapsed, _ = gettime(merge_leg.merge_h5,self.data_source.output_path,
#                                              "C:/Til/output/"+"simul_leg.h5",period)
#                         process_time['merge_leg'] += elapsed

                time_elapsed = time.time() - period_start_time
                print("period %d done" % period, end=' ')
                if config.show_timings:
                    print("(%s elapsed)." % time2str(time_elapsed))
                else:
                    print()

            total_objects = sum(period_objects[period] for period in periods)
            total_time = time.time() - main_start_time

#             if self.legislation:
#                 if self.legislation['ex_post']:
#
#                     elapsed, _ = gettime(liam2of.main)
#                     process_time['liam2of'] += elapsed
#                     elapsed, _ = gettime(of_on_liam.main,self.legislation['annee'])
#                     process_time['legislation'] += elapsed
#                     # TODO: faire un programme a part, so far ca ne marche pas pour l'ensemble
#                     # adapter n'est pas si facile, comme on veut economiser une table,
#                     # on ne peut pas faire de append directement parce qu on met 2010 apres 2011
#                     # a un moment dans le calcul
#                     elapsed, _ = gettime(merge_leg.merge_h5,self.data_source.output_path,
#                                          "C:/Til/output/"+"simul_leg.h5",None)
#                     process_time['merge_leg'] += elapsed

            if self.final_stat:
                elapsed, _ = gettime(start, period)
                process_time['Stat'] += elapsed

            total_time = time.time() - main_start_time
            time_year = 0
            if len(periods) > 1:
                nb_year_approx = periods[-1] / 100 - periods[1] / 100
                if nb_year_approx > 0:
                    time_year = total_time / nb_year_approx

            try:
                ind_per_sec = str(int(total_objects / total_time))
            except ZeroDivisionError:
                ind_per_sec = 'inf'
            print("""
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %s individuals/s/period on average

 * %s second for init_process
 * %s time/period in average
 * %s time/year in average
==========================================
""" % (
                time2str(time.time() - start_time),
                total_objects / self.periods,
                ind_per_sec,
                time2str(time_init),
                time2str(total_time / self.periods),
                time2str(time_year))
            )

            show_top_processes(process_time, 10)
            # if config.debug:
            #     show_top_expr()

            if run_console:
                console_ctx = eval_ctx.clone(entity_name=self.default_entity)
                c = console.Console(console_ctx)
                c.run()

        finally:
            if h5in is not None:
                h5in.close()
            h5out.close()
            if h5_autodump is not None:
                h5_autodump.close()
コード例 #22
0
ファイル: simulation.py プロジェクト: abozio/Myliam2
    def run(self, run_console=False):
        start_time = time.time()
        h5in, h5out, globals_data = timed(
            self.data_source.run, self.globals_def, entity_registry, self.start_period - 1
        )
        #        input_dataset = self.data_source.run(self.globals_def,
        #                                             entity_registry)
        #        output_dataset = self.data_sink.prepare(self.globals_def,
        #                                                entity_registry)
        #        output_dataset.copy(input_dataset, self.start_period - 1)
        #        for entity in input_dataset:
        #            indexed_array = buildArrayForPeriod(entity)

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide="ignore", invalid="ignore")

        process_time = defaultdict(float)
        period_objects = {}

        def simulate_period(period_idx, period, processes, entities, init=False):
            print "\nperiod", period
            if init:
                for entity in entities:
                    print "  * %s: %d individuals" % (entity.name, len(entity.array))
            else:
                print "- loading input data"
                for entity in entities:
                    print "  *", entity.name, "...",
                    timed(entity.load_period_data, period)
                    print "    -> %d individuals" % len(entity.array)
            for entity in entities:
                entity.array_period = period
                entity.array["period"] = period

            if processes:
                # build context for this period:
                const_dict = {"period": period, "nan": float("nan"), "__globals__": globals_data}

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    print "- %d/%d" % (p_num, num_processes), process.name,
                    # TODO: provide a custom __str__ method for Process &
                    # Assignment instead
                    if hasattr(process, "predictor") and process.predictor and process.predictor != process.name:
                        print "(%s)" % process.predictor,
                    print "...",
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, self, const_dict)
                    else:
                        elapsed = 0
                        print "skipped (periodicity)"

                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print "done (%s elapsed)." % time2str(elapsed)
                    else:
                        print "done."
                    self.start_console(process.entity, period, globals_data)

            print "- storing period data"
            for entity in entities:
                print "  *", entity.name, "...",
                timed(entity.store_period_data, period)
                print "    -> %d individuals" % len(entity.array)
            #            print " - compressing period data"
            #            for entity in entities:
            #                print "  *", entity.name, "...",
            #                for level in range(1, 10, 2):
            #                    print "   %d:" % level,
            #                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array) for entity in entities)

        try:
            simulate_period(0, self.start_period - 1, self.init_processes, self.entities, init=True)
            main_start_time = time.time()
            periods = range(self.start_period, self.start_period + self.periods)
            for period_idx, period in enumerate(periods):
                period_start_time = time.time()
                simulate_period(period_idx, period, self.processes, self.entities)
                time_elapsed = time.time() - period_start_time
                print "period %d done (%s elapsed)." % (period, time2str(time_elapsed))

            total_objects = sum(period_objects[period] for period in periods)
            total_time = time.time() - main_start_time
            print """
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %d individuals/s/period on average
==========================================
""" % (
                time2str(time.time() - start_time),
                total_objects / self.periods,
                total_objects / total_time,
            )

            show_top_processes(process_time, 10)
            #            if config.debug:
            #                show_top_expr()

            if run_console:
                c = console.Console(self.console_entity, periods[-1], self.globals_def, globals_data)
                c.run()

        finally:
            if h5in is not None:
                h5in.close()
            h5out.close()
コード例 #23
0
    def run(self, run_console=False):
        """Run the simulation: an init period, then ``self.periods`` periods.

        Opens the input/output data via ``self.data_source.run`` for the
        init period ``self.start_period - 1``. When ``config.autodump`` or
        ``config.autodiff`` is set, it also opens the autodump file (mode
        'w' or 'r') and stores it in ``config.autodump_file``.

        It runs ``self.init_processes`` on the init period, then
        ``self.processes`` on each period of
        ``range(self.start_period, self.start_period + self.periods)``.
        Afterwards it prints timing statistics and, if *run_console* is
        true, starts a console on ``self.default_entity``. All opened HDF5
        files are closed in the ``finally`` clause.
        """
        start_time = time.time()
        h5in, h5out, globals_data = timed(self.data_source.run,
                                          self.globals_def,
                                          self.entities_map,
                                          self.start_period - 1)

        if config.autodump or config.autodiff:
            if config.autodump:
                fname, _ = config.autodump
                mode = 'w'
            else:  # config.autodiff
                fname, _ = config.autodiff
                mode = 'r'
            fpath = os.path.join(config.output_directory, fname)
            h5_autodump = tables.open_file(fpath, mode=mode)
            config.autodump_file = h5_autodump
        else:
            h5_autodump = None

#        input_dataset = self.data_source.run(self.globals_def,
#                                             entity_registry)
#        output_dataset = self.data_sink.prepare(self.globals_def,
#                                                entity_registry)
#        output_dataset.copy(input_dataset, self.start_period - 1)
#        for entity in input_dataset:
#            indexed_array = build_period_array(entity)

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide='ignore', invalid='ignore')

        # accumulated run time per process name, summed over all periods
        process_time = defaultdict(float)
        # period -> total number of rows over all entities in that period
        period_objects = {}
        eval_ctx = EvaluationContext(self, self.entities_map, globals_data)

        def simulate_period(period_idx, period, processes, entities,
                            init=False):
            """Load, process, store and report on a single period.

            A process runs only when ``period_idx % periodicity == 0``. When
            ``config.show_timings`` is set and this is not the init period,
            a remaining-time estimate is printed, based on the average
            duration of the periods done so far.
            """
            period_start_time = time.time()

            # set current period
            eval_ctx.period = period

            if config.log_level in ("procedures", "processes"):
                print()
            print("period", period,
                  end=" " if config.log_level == "periods" else "\n")
            # NOTE(review): with init=True and a non-verbose log level this
            # falls through to the else branch and loads period data, while
            # verbose log levels do not load; confirm this is intended.
            if init and config.log_level in ("procedures", "processes"):
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                if config.log_level in ("procedures", "processes"):
                    print("- loading input data")
                    for entity in entities:
                        print("  *", entity.name, "...", end=' ')
                        timed(entity.load_period_data, period)
                        print("    -> %d individuals" % len(entity.array))
                else:
                    for entity in entities:
                        entity.load_period_data(period)
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    # set current entity
                    eval_ctx.entity_name = process.entity.name

                    if config.log_level in ("procedures", "processes"):
                        print("- %d/%d" % (p_num, num_processes), process.name,
                              end=' ')
                        print("...", end=' ')
                    # run once every `periodicity` periods
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, eval_ctx)
                    else:
                        elapsed = 0
                        if config.log_level in ("procedures", "processes"):
                            print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.log_level in ("procedures", "processes"):
                        if config.show_timings:
                            print("done (%s elapsed)." % time2str(elapsed))
                        else:
                            print("done.")
                    self.start_console(eval_ctx)

            if config.log_level in ("procedures", "processes"):
                print("- storing period data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.store_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            else:
                for entity in entities:
                    entity.store_period_data(period)
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            period_elapsed_time = time.time() - period_start_time
            if config.log_level in ("procedures", "processes"):
                print("period %d" % period, end=' ')
            print("done", end=' ')
            if config.show_timings:
                print("(%s elapsed)" % time2str(period_elapsed_time), end="")
                if init:
                    print(".")
                else:
                    # main_start_time is read from the enclosing run() scope;
                    # it is assigned before the first non-init period runs
                    main_elapsed_time = time.time() - main_start_time
                    periods_done = period_idx + 1
                    remaining_periods = self.periods - periods_done
                    avg_time = main_elapsed_time / periods_done
                    # future_time = period_elapsed_time * 0.4 + avg_time * 0.6
                    remaining_time = avg_time * remaining_periods
                    print(" - estimated remaining time: %s."
                          % time2str(remaining_time))
            else:
                print()

        print("""
=====================
 starting simulation
=====================""")
        try:
            simulate_period(0, self.start_period - 1, self.init_processes,
                            self.entities, init=True)
            # reference point for the per-period remaining-time estimates and
            # for the final throughput figures (excludes the init period)
            main_start_time = time.time()
            periods = range(self.start_period,
                            self.start_period + self.periods)
            for period_idx, period in enumerate(periods):
                simulate_period(period_idx, period,
                                self.processes, self.entities)

            total_objects = sum(period_objects[period] for period in periods)
            total_time = time.time() - main_start_time
            # a very short run can measure 0 seconds
            try:
                ind_per_sec = str(int(total_objects / total_time))
            except ZeroDivisionError:
                ind_per_sec = 'inf'

            print("""
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %s individuals/s/period on average
==========================================
""" % (time2str(time.time() - start_time),
       total_objects / self.periods,
       ind_per_sec))

            show_top_processes(process_time, 10)
#            if config.debug:
#                show_top_expr()

            if run_console:
                console_ctx = eval_ctx.clone(entity_name=self.default_entity)
                c = console.Console(console_ctx)
                c.run()

        finally:
            # close every HDF5 file opened above, whatever happened
            if h5in is not None:
                h5in.close()
            h5out.close()
            if h5_autodump is not None:
                h5_autodump.close()
コード例 #24
0
ファイル: data.py プロジェクト: gvk489/liam2
    def prepare(self, globals_def, entities, input_dataset, start_period):
        """Create the output file and seed it from the input dataset.

        Opens ``self.output_path`` for writing and copies every global of
        ``globals_def`` that has no ``'path'`` entry. Then, for each entity,
        it creates an ``/indexes/<name>`` group and either copies the input
        rows of the periods preceding ``start_period`` or creates an empty
        table. Entity objects are updated in place (``output_index_node``,
        ``output_index``, ``output_rows``, ``table``).

        On any error the output file is closed and the exception re-raised;
        on success the open file is stored in ``self.h5out``.
        """
        output_file = tables.open_file(self.output_path, mode="w")

        try:
            globals_data = input_dataset.get('globals')
            if globals_data is not None:
                globals_group = output_file.create_group("/", "globals",
                                                         "Globals")
                for name, definition in globals_def.iteritems():
                    if 'path' in definition:
                        # globals declared with a 'path' are not copied here
                        continue
                    anyarray_to_disk(globals_group, name, globals_data[name])

            input_tables = input_dataset['entities']
            entities_group = output_file.create_group("/", "entities",
                                                      "Entities")
            output_file.create_group("/", "indexes", "Indexes")
            print(" * copying tables")
            for ent_name, entity in entities.iteritems():
                print("    -", ent_name, "...", end=' ')
                entity.output_index_node = output_file.create_group(
                    "/indexes", ent_name)

                if not entity.fields.in_output:
                    print("skipped (no column in output)")
                    continue

                copy_start = time.time()

                source = input_tables.get(ent_name)
                if source is None:
                    # no input data for this entity: start from an empty table
                    output_rows = {}
                    output_table = output_file.create_table(
                        entities_group, entity.name,
                        entity.fields.in_output.dtype,
                        title="%s table" % entity.name)
                    output_index = {}
                else:
                    period_rows = source.period_index
                    # keep only the periods preceding the simulation start
                    output_rows = {}
                    for p, rows in period_rows.iteritems():
                        if p < start_period:
                            output_rows[p] = rows

                    if output_rows:
                        # stoprow = last row of the last period kept
                        _, stoprow = period_rows[max(output_rows)]
                    else:
                        stoprow = 0

                    default_values = entity.fields.default_values
                    output_table = copy_table(source.table, entities_group,
                                              entity.fields.in_output.dtype,
                                              stop=stoprow,
                                              show_progress=True,
                                              default_values=default_values)
                    output_index = source.id2rownum_per_period.copy()

                entity.output_index = output_index
                entity.output_rows = output_rows
                entity.table = output_table
                print("done (%s elapsed)." % time2str(time.time() - copy_start))
        except:
            # cleanup only: the exception is always re-raised
            output_file.close()
            raise
        self.h5out = output_file
コード例 #25
0
    def train(self, feed, epochs=1):
        """Train the GAN for ``epochs`` epochs on the batches of ``feed``.

        Every epoch shuffles the feed. It then alternates ``D_iter``
        discriminator steps (on augmented image batches) with ``G_iter``
        generator steps until ``feed.has_more_data`` becomes false. After
        each epoch:

        - the duration is appended to ``self.epochs_time``;
        - the mean losses and an estimate of the remaining time are printed;
        - every 10 epochs, a grid of 64 generated images is shown;
        - the session is saved under ``self.name``.

        Args:
            feed: batch provider exposing ``shuffle_data()``,
                ``has_more_data``, ``get_batch(...)`` and ``reset()``.
            epochs: number of epochs to run.
        """
        def get_iterations(epoch):
            # Discriminator / generator steps per round. Constant for now;
            # ``epoch`` is passed so a schedule can be plugged in later.
            D_iter = 5
            G_iter = 1
            return D_iter, G_iter

        def sample_noise(batch_size):
            # batch_size x noise_dim values drawn uniformly from [-1, 1).
            # Drawing the 2-D shape directly yields the same values as
            # drawing a flat vector and reshaping it.
            return np.random.uniform(-1, 1, size=(batch_size, self.noise_dim))

        def mean_loss(losses):
            # An epoch may end before any step of one kind ran; report nan
            # as before, but without numpy's "Mean of empty slice" warning.
            return np.mean(losses) if losses else float('nan')

        # Saver for saving the model
        self.saver = tf.train.Saver()

        with tf.Session() as sess:

            # Init the variables
            sess.run(tf.global_variables_initializer())

            for e in range(1, epochs + 1):
                t0 = time.time()

                feed.shuffle_data()

                # Mini-batches propagation and back-propagation
                D_loss_train, G_loss_train = [], []
                D_iter, G_iter = get_iterations(e)
                while feed.has_more_data:
                    # Discriminator steps on (augmented) real images
                    for _ in range(D_iter):
                        if not feed.has_more_data:
                            break

                        image_batch = feed.get_batch(augment=True, p_aug=0.5)
                        noise_batch = sample_noise(len(image_batch))
                        _, D_loss_batch = sess.run(
                            [self.disc_step, self.D_loss],
                            feed_dict={
                                self.noise: noise_batch,
                                self.image: image_batch,
                                self.is_training: True
                            })
                        D_loss_train.append(D_loss_batch)

                    # Generator steps: the image batch is only used to size
                    # the noise batch (fetching it also advances the feed)
                    for _ in range(G_iter):
                        if not feed.has_more_data:
                            break

                        image_batch = feed.get_batch()
                        noise_batch = sample_noise(len(image_batch))
                        _, G_loss_batch = sess.run(
                            [self.gen_step, self.G_loss],
                            feed_dict={
                                self.noise: noise_batch,
                                self.is_training: True
                            })
                        G_loss_train.append(G_loss_batch)

                self.epochs_time.append(time.time() - t0)

                # Reset the data
                feed.reset()

                # Print train info; the remaining time is estimated from the
                # mean epoch duration so far
                str_template = ('Epoch: {}, D Train Loss: {}, G Train Loss: {}, '
                                'Epoch Time: {}, Time to End: {}')
                print(
                    str_template.format(
                        e,
                        mean_loss(D_loss_train),
                        mean_loss(G_loss_train),
                        time2str(self.epochs_time[-1]),
                        time2str(np.mean(self.epochs_time) * (epochs - e))))

                # View generation
                if e % 10 == 0:
                    images_gen = self.generate(num_images=64, seed=1234)

                    # Plot generated images on a square grid
                    print('\tGeneration')
                    grid_size = int(np.ceil(np.sqrt(len(images_gen))))
                    plt.figure(figsize=[20, 20])
                    for idx_image, image in enumerate(images_gen):
                        plt.subplot(grid_size, grid_size, idx_image + 1)
                        plt.imshow(image)
                        plt.xticks([])
                        plt.yticks([])
                    plt.show()

                # Save the model
                self.saver.save(sess, self.name)

        print('Train finished!')
コード例 #26
0
ファイル: data.py プロジェクト: jonathangoupille/Myliam2
    def run(self, globals_fields, entities, start_period):
        input_file, dataset = index_tables(globals_fields, entities,
                                           self.input_path)
        output_file = tables.openFile(self.output_path, mode="w")

        try:
            if dataset['globals'] is not None:
                output_globals = output_file.createGroup("/", "globals",
                                                         "Globals")
                copyTable(input_file.root.globals.periodic, output_file,
                          output_globals)

            entities_tables = dataset['entities']
            output_entities = output_file.createGroup("/", "entities",
                                                      "Entities")
            print " * copying tables"
            for ent_name, entity in entities.iteritems():
                print ent_name, "..."

                # main table

                table = entities_tables[ent_name]

                entity.input_index = table.id2rownum_per_period
                entity.input_rows = table.period_index
                entity.input_table = table.table
                entity.base_period = table.base_period

# this is what should happen
#                entity.indexed_input_table = entities_tables[ent_name]
#                entity.indexed_output_table = entities_tables[ent_name]

                #TODO: copying the table and generally preparing the output
                # file should be a different method than indexing
                print " * copying table..."
                start_time = time.time()
                input_rows = entity.input_rows
                output_rows = dict((p, rows)
                                   for p, rows in input_rows.iteritems()
                                   if p < start_period)
                if output_rows:
                    # stoprow = last row of the last period before start_period
                    _, stoprow = input_rows[max(output_rows.iterkeys())]
                else:
                    stoprow = 0

                output_table = copyTable(table.table,
                                         output_file, output_entities,
                                         entity.fields, stop=stoprow,
                                         show_progress=True)
                entity.output_rows = output_rows
                print "done (%s elapsed)." % time2str(time.time() - start_time)

                print " * building array for first simulated period...",
                start_time = time.time()

                #TODO: this whole process of merging all periods is very
                # opinionated and does not allow individuals to die/disappear
                # before the simulation starts. We couldn't for example,
                # take the output of one of our simulation and
                # re-simulate only some years in the middle, because the dead
                # would be brought back to life. In conclusion, it should be
                # optional.
                entity.array, entity.id_to_rownum = \
                    buildArrayForPeriod(table.table, entity.fields,
                                        entity.input_rows,
                                        entity.input_index, start_period)
                print "done (%s elapsed)." % time2str(time.time() - start_time)
                entity.table = output_table
        except:
            input_file.close()
            output_file.close()
            raise

        return input_file, output_file, dataset['globals']
コード例 #27
0
ファイル: simulation.py プロジェクト: leeseungho90/liam2
    def run(self, run_console=False):
        start_time = time.time()
        h5in, h5out, globals_data = timed(self.data_source.run,
                                          self.globals_def,
                                          entity_registry,
                                          self.start_period - 1)

        if config.autodump or config.autodiff:
            if config.autodump:
                fname, _ = config.autodump
                mode = 'w'
            else:  # config.autodiff
                fname, _ = config.autodiff
                mode = 'r'
            fpath = os.path.join(config.output_directory, fname)
            h5_autodump = tables.openFile(fpath, mode=mode)
            config.autodump_file = h5_autodump
        else:
            h5_autodump = None

#        input_dataset = self.data_source.run(self.globals_def,
#                                             entity_registry)
#        output_dataset = self.data_sink.prepare(self.globals_def,
#                                                entity_registry)
#        output_dataset.copy(input_dataset, self.start_period - 1)
#        for entity in input_dataset:
#            indexed_array = build_period_array(entity)

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide='ignore', invalid='ignore')

        process_time = defaultdict(float)
        period_objects = {}

        def simulate_period(period_idx, period, processes, entities,
                            init=False):
            print("\nperiod", period)
            if init:
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                print("- loading input data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.load_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'__simulation__': self,
                              'period': period,
                              'nan': float('nan'),
                              '__globals__': globals_data}

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    print("- %d/%d" % (p_num, num_processes), process.name,
                          end=' ')
                    print("...", end=' ')
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, self,
                                             const_dict)
                    else:
                        elapsed = 0
                        print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print("done (%s elapsed)." % time2str(elapsed))
                    else:
                        print("done.")
                    self.start_console(process.entity, period,
                                       globals_data)

            print("- storing period data")
            for entity in entities:
                print("  *", entity.name, "...", end=' ')
                timed(entity.store_period_data, period)
                print("    -> %d individuals" % len(entity.array))
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)

        try:
            simulate_period(0, self.start_period - 1, self.init_processes,
                            self.entities, init=True)
            main_start_time = time.time()
            periods = range(self.start_period,
                            self.start_period + self.periods)
            for period_idx, period in enumerate(periods):
                period_start_time = time.time()
                simulate_period(period_idx, period,
                                self.processes, self.entities)
                time_elapsed = time.time() - period_start_time
                print("period %d done" % period, end=' ')
                if config.show_timings:
                    print("(%s elapsed)." % time2str(time_elapsed))
                else:
                    print()

            total_objects = sum(period_objects[period] for period in periods)
            total_time = time.time() - main_start_time
            try:
                ind_per_sec = str(int(total_objects / total_time))
            except ZeroDivisionError:
                ind_per_sec = 'inf'

            print("""
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %s individuals/s/period on average
==========================================
""" % (time2str(time.time() - start_time),
       total_objects / self.periods,
       ind_per_sec))

            show_top_processes(process_time, 10)
#            if config.debug:
#                show_top_expr()

            if run_console:
                c = console.Console(self.console_entity, periods[-1],
                                    self.globals_def, globals_data)
                c.run()

        finally:
            if h5in is not None:
                h5in.close()
            h5out.close()
            if h5_autodump is not None:
                h5_autodump.close()
コード例 #28
0
    def run(self, run_console=False):
        start_time = time.time()
        h5in, h5out, periodic_globals = timed(self.data_source.run,
                                              self.globals_fields,
                                              entity_registry,
                                              self.start_period - 1)
#        input_dataset = self.data_source.run(self.globals_fields,
#                                             entity_registry)
#        output_dataset = self.data_sink.prepare(self.globals_fields,
#                                                entity_registry)
#        output_dataset.copy(input_dataset, self.start_period - 1)
#        for entity in input_dataset:
#            indexed_array = buildArrayForPeriod(entity)

        if periodic_globals is not None:
            try:
                globals_periods = periodic_globals['PERIOD']
            except ValueError:
                globals_periods = periodic_globals['period']
            globals_base_period = globals_periods[0]

        process_time = defaultdict(float)
        period_objects = {}

        def simulate_period(period_idx, period, processes, entities, init=False):
            print "\nperiod", period
            if init:
                for entity in entities:
                    print "  * %s: %d individuals" % (entity.name,
                                                      len(entity.array))
            else:
                print "- loading input data"
                for entity in entities:
                    print "  *", entity.name, "...",
                    timed(entity.load_period_data, period)
                    print "    -> %d individuals" % len(entity.array)
            for entity in entities:
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'period': period,
                              'nan': float('nan')}

                # update "globals" with their value for this period
                if periodic_globals is not None:
                    globals_row = period - globals_base_period
                    if globals_row < 0:
                        #XXX: use missing values instead?
                        raise Exception('Missing globals data for period %d'
                                        % period)
                    period_globals = periodic_globals[globals_row]
                    const_dict.update((k, period_globals[k])
                                      for k in period_globals.dtype.names)
                    const_dict['__globals__'] = periodic_globals

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def
                    print "- %d/%d" % (p_num, num_processes), process.name,
                    #TODO: provided a custom __str__ method for Process &
                    # Assignment instead
                    if hasattr(process, 'predictor') and process.predictor \
                       and process.predictor != process.name:
                        print "(%s)" % process.predictor,
                    print "...",
                    if period_idx % periodicity == 0: 
                        elapsed, _ = gettime(process.run_guarded, self,
                                             const_dict)
                    else:
                        elapsed = 0
                        print "skipped (periodicity)"
                 

                    process_time[process.name] += elapsed
                    print "done (%s elapsed)." % time2str(elapsed)
                    self.start_console(process.entity, period)

            print "- storing period data"
            for entity in entities:
                print "  *", entity.name, "...",
                timed(entity.store_period_data, period)
                print "    -> %d individuals" % len(entity.array)
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)

        try:
            simulate_period(0,self.start_period - 1, self.init_processes,
                            self.entities, init=True)
            main_start_time = time.time()
            periods = range(self.start_period,
                            self.start_period + self.periods)
            for period_idx, period in enumerate(periods):
                period_start_time = time.time()
                simulate_period(period_idx, period, self.processes, self.entities)
                time_elapsed = time.time() - period_start_time
                print "period %d done (%s elapsed)." % (period,
                                                        time2str(time_elapsed))

            total_objects = sum(period_objects[period] for period in periods)
            total_time = time.time() - main_start_time
            print """
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %d individuals/s/period on average
==========================================
""" % (time2str(time.time() - start_time),
       total_objects / self.periods,
       total_objects / total_time)

            show_top_processes(process_time, 10)

            if run_console:
                if self.default_entity is not None:
                    entity = entity_registry[self.default_entity]
                else:
                    entity = None
                c = console.Console(entity, periods[-1])
                c.run()

        finally:
            if h5in is not None:
                h5in.close()
            h5out.close()
コード例 #29
0
        def simulate_period(period_idx, period, processes, entities, init=False):
            print "\nperiod", period
            if init:
                for entity in entities:
                    print "  * %s: %d individuals" % (entity.name,
                                                      len(entity.array))
            else:
                print "- loading input data"
                for entity in entities:
                    print "  *", entity.name, "...",
                    timed(entity.load_period_data, period)
                    print "    -> %d individuals" % len(entity.array)
            for entity in entities:
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'period': period,
                              'nan': float('nan')}

                # update "globals" with their value for this period
                if periodic_globals is not None:
                    globals_row = period - globals_base_period
                    if globals_row < 0:
                        #XXX: use missing values instead?
                        raise Exception('Missing globals data for period %d'
                                        % period)
                    period_globals = periodic_globals[globals_row]
                    const_dict.update((k, period_globals[k])
                                      for k in period_globals.dtype.names)
                    const_dict['__globals__'] = periodic_globals

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def
                    print "- %d/%d" % (p_num, num_processes), process.name,
                    #TODO: provided a custom __str__ method for Process &
                    # Assignment instead
                    if hasattr(process, 'predictor') and process.predictor \
                       and process.predictor != process.name:
                        print "(%s)" % process.predictor,
                    print "...",
                    if period_idx % periodicity == 0: 
                        elapsed, _ = gettime(process.run_guarded, self,
                                             const_dict)
                    else:
                        elapsed = 0
                        print "skipped (periodicity)"
                 

                    process_time[process.name] += elapsed
                    print "done (%s elapsed)." % time2str(elapsed)
                    self.start_console(process.entity, period)

            print "- storing period data"
            for entity in entities:
                print "  *", entity.name, "...",
                timed(entity.store_period_data, period)
                print "    -> %d individuals" % len(entity.array)
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
コード例 #30
0
ファイル: simulation.py プロジェクト: AnneDy/Til-Liam
        def simulate_period(period_idx, period, periods, processes, entities,
                            init=False):
            """Simulate a single period for every entity of ``entities``.

            Loading: when ``init`` is true, only the entity sizes are
            reported; otherwise the period data is loaded first.

            ``periods`` is indexed with ``period_idx + 1``, and that element
            is asserted to equal ``period``. Each element of ``processes`` is
            a ``(process, periodicity, start)`` triple:

            - an int periodicity runs the process when ``period_idx`` is a
              multiple of it;
            - a string periodicity is looked up in ``time_period`` and
              compared with the simulation time scale and with
              ``period % 100``.

            The ``periodicity`` constant passed to a process is the
            simulation time step, or the process's own periodicity for
            string periodicities. In both cases it is multiplied by
            ``1 - 2 * self.retro`` (presumably a sign flip for retrospective
            runs -- TODO confirm ``retro`` is a bool).

            Finally the period data is stored and the total number of
            individuals is recorded in ``period_objects[period]``.
            """
            print("\nperiod", period)
            if init:
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                print("- loading input data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.load_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # periodicity seen by processes unless they override it
                default_periodicity = \
                    time_period[self.time_scale] * (1 - 2 * (self.retro))
                # build context for this period:
                const_dict = {'period_idx': period_idx + 1,
                              'periods': periods,
                              'periodicity': default_periodicity,
                              'format_date': self.time_scale,
                              'nan': float('nan'),
                              '__globals__': globals_data}
                assert(periods[period_idx + 1] == period)

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity, start = process_def
                    # A string-periodicity process overrides 'periodicity'
                    # below; reset it so that override does not leak into the
                    # processes that follow it in this period.
                    const_dict['periodicity'] = default_periodicity
                    print("- %d/%d" % (p_num, num_processes), process.name,
                          end=' ')
                    #TODO: provide a custom __str__ method for Process &
                    # Assignment instead
                    if hasattr(process, 'predictor') and process.predictor \
                       and process.predictor != process.name:
                        print("(%s)" % process.predictor, end=' ')
                    print("...", end=' ')
                    # TODO: change that
                    if isinstance(periodicity, int):
                        if period_idx % periodicity == 0:
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)
                        else:
                            elapsed = 0
                            print("skipped (periodicity)")
                    else:
                        assert (periodicity in time_period)
                        periodicity_process = time_period[periodicity]
                        periodicity_simul = time_period[self.time_scale]
                        month_idx = period % 100
                        # first condition, to run a process with start == 12
                        # each year even if year are yyyy01
                        # modify start if periodicity_simul is not month
                        start = int(start/periodicity_simul-0.01)*periodicity_simul + 1

                        if (periodicity_process <= periodicity_simul and self.time_scale != 'year0') or \
                                 month_idx % periodicity_process == start % periodicity_process:
                            const_dict['periodicity'] = periodicity_process*(1 - 2*(self.retro))
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)
                        else:
                            elapsed = 0
                            print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print("done (%s elapsed)." % time2str(elapsed))
                    else:
                        print("done.")
                    self.start_console(process.entity, period,
                                       globals_data)

            print("- storing period data")
            for entity in entities:
                print("  *", entity.name, "...", end=' ')
                timed(entity.store_period_data, period)
                print("    -> %d individuals" % len(entity.array))
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
コード例 #31
0
ファイル: data.py プロジェクト: AlexisEidelman/liam2_mahdi
    def run(self, globals_def, entities, start_period):
        """Index the input file and prepare the output file for a simulation.

        Copies the globals listed in ``globals_def`` to a freshly created
        output file. For each entity, it copies the rows of the input table
        that belong to periods before ``start_period``, then builds the
        array for the first simulated period. Per-entity attributes
        (``input_*``, ``output_*``, ``array``, ``id_to_rownum``, ``table``...)
        are set on the entity objects as a side effect.

        :param globals_def: globals definitions; only the names it contains
            are copied to the output file.
        :param entities: mapping of entity name -> entity object.
        :param start_period: first simulated period.
        :returns: ``(input_file, output_file, dataset['globals'])``. Both
            files are left open on success; on failure both are closed
            before the exception propagates.
        """
        input_file, dataset = index_tables(globals_def, entities,
                                           self.input_path)
        output_file = None
        try:
            # Opened inside the try so that a failure here (e.g. an
            # unwritable output path) still closes input_file.
            output_file = tables.open_file(self.output_path, mode="w")

            globals_node = getattr(input_file.root, 'globals', None)
            if globals_node is not None:
                output_globals = output_file.create_group("/", "globals",
                                                          "Globals")
                # index_tables already checks whether all tables exist and
                # are coherent with globals_def
                for name in globals_def:
                    if name in globals_node:
                        #noinspection PyProtectedMember
                        getattr(globals_node, name)._f_copy(output_globals)

            entities_tables = dataset['entities']
            output_entities = output_file.create_group("/", "entities",
                                                       "Entities")
            # parent group of the per-entity "/indexes/<name>" groups below
            output_file.create_group("/", "indexes", "Indexes")
            print(" * copying tables")
            for ent_name, entity in entities.iteritems():
                print(ent_name, "...")

                # main table
                table = entities_tables[ent_name]

                index_node = output_file.create_group("/indexes", ent_name)
                entity.output_index_node = index_node
                entity.input_index = table.id2rownum_per_period
                entity.input_rows = table.period_index
                entity.input_table = table.table
                entity.base_period = table.base_period

                # this is what should happen
                # entity.indexed_input_table = entities_tables[ent_name]
                # entity.indexed_output_table = entities_tables[ent_name]

                #TODO: copying the table and generally preparing the output
                # file should be a different method than indexing
                print(" * copying table...")
                start_time = time.time()
                input_rows = entity.input_rows
                # only the periods preceding the simulation are copied
                output_rows = {p: rows for p, rows in input_rows.iteritems()
                               if p < start_period}
                if output_rows:
                    # stoprow = last row of the last period before start_period
                    _, stoprow = input_rows[max(output_rows)]
                else:
                    stoprow = 0

                output_table = copy_table(table.table, output_entities,
                                          entity.fields, stop=stoprow,
                                          show_progress=True)
                entity.output_rows = output_rows
                print("done (%s elapsed)." % time2str(time.time() - start_time))

                print(" * building array for first simulated period...",
                      end=' ')
                start_time = time.time()

                #TODO: this whole process of merging all periods is very
                # opinionated and does not allow individuals to die/disappear
                # before the simulation starts. We couldn't for example,
                # take the output of one of our simulation and
                # re-simulate only some years in the middle, because the dead
                # would be brought back to life. In conclusion, it should be
                # optional.
                entity.array, entity.id_to_rownum = \
                    build_period_array(table.table, entity.fields,
                                       entity.input_rows,
                                       entity.input_index, start_period)
                assert isinstance(entity.array, ColumnArray)
                entity.array_period = start_period
                print("done (%s elapsed)." % time2str(time.time() - start_time))
                entity.table = output_table
        except BaseException:
            # The caller never receives the handles on failure, so release
            # them here before propagating (including KeyboardInterrupt).
            input_file.close()
            if output_file is not None:
                output_file.close()
            raise

        return input_file, output_file, dataset['globals']
コード例 #32
0
ファイル: simulation.py プロジェクト: fagan2888/liam2
    def run_single(self, run_console=False, run_num=None):
        start_time = time.time()

        input_dataset = timed(self.data_source.load,
                              self.globals_def,
                              self.entities_map)

        globals_data = input_dataset.get('globals')
        timed(self.data_sink.prepare, self.globals_def, self.entities_map,
              input_dataset, self.start_period - 1)

        print(" * building arrays for first simulated period")
        for ent_name, entity in self.entities_map.iteritems():
            print("    -", ent_name, "...", end=' ')
            # TODO: this whole process of merging all periods is very
            # opinionated and does not allow individuals to die/disappear
            # before the simulation starts. We couldn't for example,
            # take the output of one of our simulation and
            # re-simulate only some years in the middle, because the dead
            # would be brought back to life. In conclusion, it should be
            # optional.
            timed(entity.build_period_array, self.start_period - 1)
        print("done.")

        if config.autodump or config.autodiff:
            if config.autodump:
                fname, _ = config.autodump
                mode = 'w'
            else:  # config.autodiff
                fname, _ = config.autodiff
                mode = 'r'
            fpath = os.path.join(config.output_directory, fname)
            h5_autodump = tables.open_file(fpath, mode=mode)
            config.autodump_file = h5_autodump
        else:
            h5_autodump = None

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide='ignore', invalid='ignore')

        process_time = defaultdict(float)
        period_objects = {}
        eval_ctx = EvaluationContext(self, self.entities_map, globals_data)

        def simulate_period(period_idx, period, processes, entities,
                            init=False):
            period_start_time = time.time()

            # set current period
            eval_ctx.period = period

            if config.log_level in ("functions", "processes"):
                print()
            print("period", period,
                  end=" " if config.log_level == "periods" else "\n")
            if init and config.log_level in ("functions", "processes"):
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                if config.log_level in ("functions", "processes"):
                    print("- loading input data")
                    for entity in entities:
                        print("  *", entity.name, "...", end=' ')
                        timed(entity.load_period_data, period)
                        print("    -> %d individuals" % len(entity.array))
                else:
                    for entity in entities:
                        entity.load_period_data(period)
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    # set current entity
                    eval_ctx.entity_name = process.entity.name

                    if config.log_level in ("functions", "processes"):
                        print("- %d/%d" % (p_num, num_processes), process.name,
                              end=' ')
                        print("...", end=' ')
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, eval_ctx)
                    else:
                        elapsed = 0
                        if config.log_level in ("functions", "processes"):
                            print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.log_level in ("functions", "processes"):
                        if config.show_timings:
                            print("done (%s elapsed)." % time2str(elapsed))
                        else:
                            print("done.")
                    self.start_console(eval_ctx)

            if config.log_level in ("functions", "processes"):
                print("- storing period data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.store_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            else:
                for entity in entities:
                    entity.store_period_data(period)
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            period_elapsed_time = time.time() - period_start_time
            if config.log_level in ("functions", "processes"):
                print("period %d" % period, end=' ')
            print("done", end=' ')
            if config.show_timings:
                print("(%s elapsed)" % time2str(period_elapsed_time), end="")
                if init:
                    print(".")
                else:
                    main_elapsed_time = time.time() - main_start_time
                    periods_done = period_idx + 1
                    remaining_periods = self.periods - periods_done
                    avg_time = main_elapsed_time / periods_done
                    # future_time = period_elapsed_time * 0.4 + avg_time * 0.6
                    remaining_time = avg_time * remaining_periods
                    print(" - estimated remaining time: %s."
                          % time2str(remaining_time))
            else:
                print()

        print("""
=====================
 starting simulation
=====================""")
        try:
            simulate_period(0, self.start_period - 1, self.init_processes,
                            self.entities, init=True)
            main_start_time = time.time()
            periods = range(self.start_period,
                            self.start_period + self.periods)
            for period_idx, period in enumerate(periods):
                simulate_period(period_idx, period,
                                self.processes, self.entities)

            total_objects = sum(period_objects[period] for period in periods)
            avg_objects = str(total_objects // self.periods) \
                if self.periods else 'N/A'
            main_elapsed_time = time.time() - main_start_time
            ind_per_sec = str(int(total_objects / main_elapsed_time)) \
                if main_elapsed_time else 'inf'

            print("""
==========================================
 simulation done
==========================================
 * %s elapsed
 * %s individuals on average
 * %s individuals/s/period on average
==========================================
""" % (time2str(time.time() - start_time), avg_objects, ind_per_sec))

            show_top_processes(process_time, 10)
#            if config.debug:
#                show_top_expr()

            if run_console:
                ent_name = self.default_entity
                if ent_name is None and len(eval_ctx.entities) == 1:
                    ent_name = eval_ctx.entities.keys()[0]
                # FIXME: fresh_data prevents the old (cloned) EvaluationContext
                # to be referenced from each EntityContext, which lead to period
                # being fixed to the last period of the simulation. This should
                # be fixed in EvaluationContext.copy but the proper fix breaks
                # stuff (see the comments there)
                console_ctx = eval_ctx.clone(fresh_data=True,
                                             entity_name=ent_name)
                c = console.Console(console_ctx)
                c.run()

        finally:
            self.close()
            if h5_autodump is not None:
                h5_autodump.close()
            if self.minimal_output:
                output_path = self.data_sink.output_path
                dirname = os.path.dirname(output_path)
                try:
                    os.remove(output_path)
                    os.rmdir(dirname)
                except OSError:
                    print("WARNING: could not delete temporary directory: %r"
                          % dirname)
コード例 #33
0
def _resolve_checkpoint_file(name, checkpoint_dir):
    """Return the checkpoint file to resume from, or ``None`` if not found.

    ``name`` is first looked up inside ``checkpoint_dir``, which is where
    ``main_train`` writes its checkpoints. If it is not there, ``name`` is
    used as a path on its own. An absolute ``name`` resolves to itself in
    both cases.
    """
    for candidate in (Path(checkpoint_dir).joinpath(name), Path(name)):
        if candidate.is_file():
            return candidate
    return None


def _init_weights(net):
    """Initialise a freshly built network in place.

    Conv2d weights get Kaiming-normal initialisation; Linear biases (when
    the layer has one) are zeroed. Other modules keep PyTorch's defaults.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)


def _evaluate(net, data_loader, criterion, cuda, device):
    """Return the per-sample average of ``criterion`` over ``data_loader``.

    Each batch loss is weighted by the batch size (``criterion`` is assumed
    to use ``reduction="mean"``), so a smaller last batch does not skew the
    result.

    :raises ValueError: if ``data_loader`` yields no samples.
    """
    net.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, label in data_loader:
            if cuda:
                data = data.to(device)
                label = label.to(device)
            # add a dimension at axis 1, exactly as done during training
            label = label.unsqueeze(1)
            predict = net(data)
            batch_len = len(label)
            total_loss += criterion(predict, label).item() * batch_len
            num_samples += batch_len
    if num_samples == 0:
        raise ValueError("validation data loader yielded no samples")
    return total_loss / num_samples


def main_train(args):
    """Train an ISONet model, optionally resuming from a checkpoint.

    Reads from ``args``: ``resume_training`` (checkpoint to resume from, or
    None), ``cuda``, ``batch_size``, ``milestones``, ``lr``, ``epochs``,
    ``best_model_name``, ``data_path`` and ``save_every_epochs``.

    Epochs are numbered from 1 to ``args.epochs`` inclusive. After each
    epoch the model is evaluated on the validation set. Improved weights
    are saved to ``models/<best_model_name>``. A full checkpoint (model,
    optimizer, scheduler, epoch, best loss) is written to
    ``models/checkpoint/`` every ``save_every_epochs`` epochs.
    """
    model_path = Path("models")
    checkpoint_path = model_path.joinpath("checkpoint")

    # Read the command-line arguments. The checkpoint name is resolved the
    # same way for validation and loading.
    resume = args.resume_training
    resume_file = None
    if resume is not None:
        resume_file = _resolve_checkpoint_file(resume, checkpoint_path)
        if resume_file is None:
            print(f"{resume} 不是一个合法的文件!")
            return
        print(f"加载检查点:{resume_file}")
    cuda = args.cuda
    batch_size = args.batch_size
    milestones = args.milestones
    lr = args.lr
    total_epoch = args.epochs
    best_model_name = args.best_model_name
    checkpoint_name = args.best_model_name
    data_path = args.data_path
    start_epoch = 1

    print("加载数据....")
    dataset = ISONetData(data_path=data_path)
    dataset_test = ISONetData(data_path=data_path, train=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=6,
                             pin_memory=True)
    data_loader_test = DataLoader(dataset=dataset_test,
                                  batch_size=batch_size,
                                  shuffle=False)
    print("成功加载数据...")
    print(f"训练集数量: {len(dataset)}")
    print(f"验证集数量: {len(dataset_test)}")

    # creates "models" as well if needed
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    device = None
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        print("cuda 无效!")
        cuda = False

    net = ISONet()
    criterion = nn.MSELoss(reduction="mean")
    optimizer = optim.Adam(net.parameters(), lr=lr)

    if cuda:
        net = net.to(device=device)
        criterion = criterion.to(device=device)

    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=milestones,
                            gamma=0.1)

    # None means "no best model saved yet"
    best_test_loss = None
    # Resume training
    if resume_file is not None:
        print("恢复训练中...")
        # Load onto the CPU so that a checkpoint written on a GPU machine
        # also loads without one. load_state_dict then copies the tensors
        # to the devices of the already-placed parameters.
        checkpoint = torch.load(resume_file, map_location="cpu")
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        # 0 used to be the "no best model yet" marker (and could be stored
        # in checkpoints), so treat it as unset
        best_test_loss = checkpoint["best_test_loss"] or None

        print(f"从第[{start_epoch}]轮开始训练...")
        print(f"上一次的损失为: [{best_test_loss}]...")
    else:
        _init_weights(net)

    writer = SummaryWriter()
    record = 0
    try:
        for epoch in range(start_epoch, total_epoch + 1):
            print(f"开始第 [{epoch}] 轮训练...")
            net.train()
            writer.add_scalar("Train/Learning Rate",
                              scheduler.get_last_lr()[0], epoch)
            for i, (data, label) in enumerate(data_loader):
                if i == 0:
                    start_time = time.time()
                if cuda:
                    data = data.to(device=device)
                    label = label.to(device=device)
                label = label.unsqueeze(1)

                optimizer.zero_grad()
                output = net(data)
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()

                # progress report every 500 batches
                if i % 500 == 499:
                    end_time = time.time()
                    elapsed = end_time - start_time
                    samples_done = (i + 1) * batch_size
                    print(
                        f">>> epoch[{epoch}] loss[{loss.item():.4f}]  {samples_done}/{len(dataset)} lr{scheduler.get_last_lr()} ",
                        end="")
                    # remaining samples at the rate of the last 500 batches
                    left_time = max(0, len(dataset) - samples_done) / (
                        500 * batch_size) * elapsed
                    print(
                        f"耗费时间:[{elapsed:.2f}]秒,估计剩余时间: [{left_time:.2f}]秒"
                    )
                    start_time = end_time
                # log to tensorboard every 128 batches
                if i % 128 == 127:
                    writer.add_scalar("Train/loss", loss.item(), record)
                    record += 1

            # validate
            print("测试模型...")
            test_loss = _evaluate(net, data_loader_test, criterion, cuda,
                                  device)
            print(
                f'\nTest Data: Average batch[{batch_size}] loss: {test_loss:.4f}\n'
            )
            scheduler.step()

            writer.add_scalar("Test/Loss", test_loss, epoch)

            if best_test_loss is None:
                print("保存模型中...")
                torch.save(net.state_dict(),
                           model_path.joinpath(best_model_name))
                best_test_loss = test_loss
            elif test_loss < best_test_loss:
                # keep the better model
                print("获取到更好的模型,保存中...")
                torch.save(net.state_dict(),
                           model_path.joinpath(best_model_name))
                best_test_loss = test_loss

            # Save a checkpoint. It is built after the update above so that
            # it records the best loss including this epoch.
            if epoch % args.save_every_epochs == 0:
                checkpoint = {
                    "net": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "scheduler": scheduler.state_dict(),
                    "best_test_loss": best_test_loss
                }
                c_time = time2str()
                torch.save(
                    checkpoint,
                    checkpoint_path.joinpath(
                        f"{checkpoint_name}_{epoch}_{c_time}.cpth"))
                print(f"保存检查点: [{checkpoint_name}_{epoch}_{c_time}.cpth]...\n")
    finally:
        # flush pending tensorboard events and release the event file
        writer.close()
コード例 #34
0
ファイル: simulation.py プロジェクト: fagan2888/liam2
        def simulate_period(period_idx, period, processes, entities,
                            init=False):
            period_start_time = time.time()

            # set current period
            eval_ctx.period = period

            if config.log_level in ("functions", "processes"):
                print()
            print("period", period,
                  end=" " if config.log_level == "periods" else "\n")
            if init and config.log_level in ("functions", "processes"):
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                if config.log_level in ("functions", "processes"):
                    print("- loading input data")
                    for entity in entities:
                        print("  *", entity.name, "...", end=' ')
                        timed(entity.load_period_data, period)
                        print("    -> %d individuals" % len(entity.array))
                else:
                    for entity in entities:
                        entity.load_period_data(period)
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity = process_def

                    # set current entity
                    eval_ctx.entity_name = process.entity.name

                    if config.log_level in ("functions", "processes"):
                        print("- %d/%d" % (p_num, num_processes), process.name,
                              end=' ')
                        print("...", end=' ')
                    if period_idx % periodicity == 0:
                        elapsed, _ = gettime(process.run_guarded, eval_ctx)
                    else:
                        elapsed = 0
                        if config.log_level in ("functions", "processes"):
                            print("skipped (periodicity)")

                    process_time[process.name] += elapsed
                    if config.log_level in ("functions", "processes"):
                        if config.show_timings:
                            print("done (%s elapsed)." % time2str(elapsed))
                        else:
                            print("done.")
                    self.start_console(eval_ctx)

            if config.log_level in ("functions", "processes"):
                print("- storing period data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.store_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            else:
                for entity in entities:
                    entity.store_period_data(period)
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            period_elapsed_time = time.time() - period_start_time
            if config.log_level in ("functions", "processes"):
                print("period %d" % period, end=' ')
            print("done", end=' ')
            if config.show_timings:
                print("(%s elapsed)" % time2str(period_elapsed_time), end="")
                if init:
                    print(".")
                else:
                    main_elapsed_time = time.time() - main_start_time
                    periods_done = period_idx + 1
                    remaining_periods = self.periods - periods_done
                    avg_time = main_elapsed_time / periods_done
                    # future_time = period_elapsed_time * 0.4 + avg_time * 0.6
                    remaining_time = avg_time * remaining_periods
                    print(" - estimated remaining time: %s."
                          % time2str(remaining_time))
            else:
                print()
コード例 #35
0
ファイル: simulation.py プロジェクト: AnneDy/Til-Liam
    def run(self, run_console=False):
        start_time = time.time()
        
        h5in, h5out, globals_data = timed(self.data_source.run,
                                          self.globals_def,
                                          entity_registry,
                                          self.init_period)
        
        if config.autodump or config.autodiff:
            if config.autodump:
                fname, _ = config.autodump
                mode = 'w'
            else:  # config.autodiff
                fname, _ = config.autodiff
                mode = 'r'
            fpath = os.path.join(config.output_directory, fname)
            h5_autodump = tables.openFile(fpath, mode=mode)
            config.autodump_file = h5_autodump
        else:
            h5_autodump = None
            
#        input_dataset = self.data_source.run(self.globals_def,
#                                             entity_registry)
#        output_dataset = self.data_sink.prepare(self.globals_def,
#                                                entity_registry)
#        output_dataset.copy(input_dataset, self.init_period - 1)
#        for entity in input_dataset:
#            indexed_array = buildArrayForPeriod(entity)

        # tell numpy we do not want warnings for x/0 and 0/0
        np.seterr(divide='ignore', invalid='ignore')

        process_time = defaultdict(float)
        period_objects = {}
        
        def simulate_period(period_idx, period, periods, processes, entities,
                            init=False):
            print("\nperiod", period)
            if init:
                for entity in entities:
                    print("  * %s: %d individuals" % (entity.name,
                                                      len(entity.array)))
            else:
                print("- loading input data")
                for entity in entities:
                    print("  *", entity.name, "...", end=' ')
                    timed(entity.load_period_data, period)
                    print("    -> %d individuals" % len(entity.array))
            for entity in entities:
                entity.array_period = period
                entity.array['period'] = period

            if processes:
                # build context for this period:
                const_dict = {'period_idx': period_idx+1,
                              'periods': periods,
                              'periodicity': time_period[self.time_scale]*(1 - 2*(self.retro)),
                              'format_date': self.time_scale,
                              'nan': float('nan'),
                              '__globals__': globals_data}
                assert(periods[period_idx+1] == period)

                num_processes = len(processes)
                for p_num, process_def in enumerate(processes, start=1):
                    process, periodicity, start = process_def
                    print("- %d/%d" % (p_num, num_processes), process.name,
                          end=' ')
                    #TODO: provide a custom __str__ method for Process &
                    # Assignment instead
                    if hasattr(process, 'predictor') and process.predictor \
                       and process.predictor != process.name:
                        print("(%s)" % process.predictor, end=' ')
                    print("...", end=' ')
                    # TDOD: change that
                    if isinstance(periodicity, int ):
                        if period_idx % periodicity == 0:
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)
                        else:
                            elapsed = 0
                            print("skipped (periodicity)")
                    else:
                        assert (periodicity  in time_period)
                        periodicity_process = time_period[periodicity]
                        periodicity_simul = time_period[self.time_scale]
                        month_idx = period % 100  
                        # first condition, to run a process with start == 12
                        # each year even if year are yyyy01
                        #modify start if periodicity_simul is not month
                        start = int(start/periodicity_simul-0.01)*periodicity_simul + 1
                        
                        if (periodicity_process <= periodicity_simul and self.time_scale != 'year0') or \
                                 month_idx % periodicity_process == start % periodicity_process:
                            const_dict['periodicity'] = periodicity_process*(1 - 2*(self.retro))
                            elapsed, _ = gettime(process.run_guarded, self,
                                                 const_dict)                        
                        else:
                            elapsed = 0
                            print("skipped (periodicity)")
                            
                    process_time[process.name] += elapsed
                    if config.show_timings:
                        print("done (%s elapsed)." % time2str(elapsed))
                    else:
                        print("done.")
                    self.start_console(process.entity, period,
                                       globals_data)

#             pdb.set_trace()
            #self.entities[2].table
            print("- storing period data")
            for entity in entities:
                print("  *", entity.name, "...", end=' ')
                timed(entity.store_period_data, period)
                print("    -> %d individuals" % len(entity.array))
#            print " - compressing period data"
#            for entity in entities:
#                print "  *", entity.name, "...",
#                for level in range(1, 10, 2):
#                    print "   %d:" % level,
#                    timed(entity.compress_period_data, level)
            period_objects[period] = sum(len(entity.array)
                                         for entity in entities)
            
            
            

        try:
            assert(self.time_scale in time_period)
            month_periodicity = time_period[self.time_scale]
            time_direction = 1 - 2*(self.retro)
            time_step = month_periodicity*time_direction
            
            periods = [ self.init_period + int(t/12)*100 + t%12   
                        for t in range(0, (self.periods+1)*time_step, time_step)]
            if self.time_scale == 'year0':
                periods = [ self.init_period + t for t in range(0, (self.periods+1))]
            print("simulated period are going to be: ",periods)
            
            init_start_time = time.time()
            simulate_period(0, self.init_period, [None,periods[0]], self.init_processes,
                            self.entities, init=True)
            
            time_init = time.time() - init_start_time
            main_start_time = time.time()
        
            for period_idx, period in enumerate(periods[1:]):
                period_start_time = time.time()
                simulate_period(period_idx, period, periods,
                                self.processes, self.entities)

#                 if self.legislation:                
#                     if not self.legislation['ex_post']:
#           
#                         elapsed, _ = gettime(liam2of.main,period)
#                         process_time['liam2of'] += elapsed
#                         elapsed, _ = gettime(of_on_liam.main,self.legislation['annee'],[period])
#                         process_time['legislation'] += elapsed
#                         elapsed, _ = gettime(merge_leg.merge_h5,self.data_source.output_path,
#                                              "C:/Til/output/"+"simul_leg.h5",period)                            
#                         process_time['merge_leg'] += elapsed

                time_elapsed = time.time() - period_start_time
                print("period %d done" % period, end=' ')
                if config.show_timings:
                    print("(%s elapsed)." % time2str(time_elapsed))
                else:
                    print()
                    
            total_objects = sum(period_objects[period] for period in periods)
 
#             if self.legislation:           
#                 if self.legislation['ex_post']:
#                        
#                     elapsed, _ = gettime(liam2of.main)
#                     process_time['liam2of'] += elapsed
#                     elapsed, _ = gettime(of_on_liam.main,self.legislation['annee'])
#                     process_time['legislation'] += elapsed
#                     # TODO: faire un programme a part, so far ca ne marche pas pour l'ensemble
#                     # adapter n'est pas si facile, comme on veut economiser une table, 
#                     # on ne peut pas faire de append directement parce qu on met 2010 apres 2011
#                     # a un moment dans le calcul
#                     elapsed, _ = gettime(merge_leg.merge_h5,self.data_source.output_path,
#                                          "C:/Til/output/"+"simul_leg.h5",None)                            
#                     process_time['merge_leg'] += elapsed


            if self.final_stat:
                elapsed, _ = gettime(stat,period)
                process_time['Stat'] += elapsed              

            total_time = time.time() - main_start_time
            time_year = 0
            if len(periods)>1:
                nb_year_approx = periods[-1]/100 - periods[1]/100
                if nb_year_approx > 0 : 
                    time_year = total_time/nb_year_approx
                
            print ("""
==========================================
 simulation done
==========================================
 * %s elapsed
 * %d individuals on average
 * %d individuals/s/period on average
 * %s second for init_process
 * %s time/period in average
 * %s time/year in average
==========================================
""" % (time2str(time.time() - start_time),
       total_objects / self.periods,
       total_objects / total_time,
        time2str(time_init),
        time2str(total_time / self.periods),
        time2str(time_year)))
            
            show_top_processes(process_time, 10)
#            if config.debug:
#                show_top_expr()

            if run_console:
                c = console.Console(self.console_entity, periods[-1],
                                    self.globals_def, globals_data)
                c.run()

        finally:
            if h5in is not None:
                h5in.close()
            h5out.close()
            if h5_autodump is not None:
                h5_autodump.close()
コード例 #36
0
ファイル: main.py プロジェクト: Agent-INF/CycleGAN
    def train(self):
        """Train the CycleGAN generators and discriminators.

        Loads the dataset, builds the model and its losses, then runs the
        training loop from the current value of ``self.global_step`` up to
        (but excluding) ``self._max_step``. Each batch performs four
        optimisation steps in this order: G_A, D_B, G_B, D_A. The
        discriminators are fed fake images returned by
        ``self.fake_image_pool``.

        A checkpoint is written to ``self._ckpt_dir`` at the start of every
        epoch, and once more after the last epoch (when at least one epoch
        ran), so the final epoch's training is persisted. Summaries are
        written to ``self._log_dir``.

        Learning-rate schedule: ``self._base_lr`` for epochs < 100, then a
        linear decay that reaches zero at epoch 200. The rate is clamped at
        zero so runs longer than 200 epochs never use a negative rate.
        """
        # Load Dataset from the dataset folder
        self.inputs = utils.load_data(self._dataset_name,
                                      utils.SIZE_BEFORE_CROP, self._mode, True,
                                      self._flipping)

        # Build the network
        self.model_setup()

        # Loss function calculations
        self.compute_losses()

        # Initializing the global and local variables
        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver()

        max_image_num = utils.get_data_size(self._dataset_name, self._mode)

        with tf.Session() as sess:
            sess.run(init)

            # Restore the model to continue from the last checkpoint
            if self._restore:
                ckpt_fname = tf.train.latest_checkpoint(self._weight_dir)
                saver.restore(sess, ckpt_fname)

            if not os.path.exists(self._log_dir):
                os.makedirs(self._log_dir)

            writer = tf.summary.FileWriter(self._log_dir)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                # Resume from the epoch stored in global_step.
                start_epoch = sess.run(self.global_step)
                for epoch in range(start_epoch, self._max_step):
                    print("In the epoch ", epoch)
                    start_time = time()
                    # Saved *before* this epoch trains; global_step still
                    # equals `epoch`, so a restore re-runs this epoch.
                    saver.save(sess, self._ckpt_dir, global_step=epoch)

                    # Constant LR for the first 100 epochs, then linear decay
                    # to zero at epoch 200; clamped so it never goes negative.
                    if epoch < 100:
                        curr_lr = self._base_lr
                    else:
                        curr_lr = max(0.0, self._base_lr - self._base_lr *
                                      (epoch - 100) / 100)

                    self.save_images(sess, epoch)

                    for batch in range(0, max_image_num):
                        # Global summary step shared by all four updates.
                        step = epoch * max_image_num + batch

                        inputs = sess.run(self.inputs)
                        # Feed entries common to every optimisation step.
                        base_feed = {
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.learning_rate: curr_lr
                        }

                        # Optimizing the G_A network
                        _, fake_b_temp, ga_loss, summary_str = sess.run(
                            [
                                self.gen_a_trainer, self.fake_images_b,
                                self.gen_a_loss, self.gen_a_loss_summ
                            ],
                            feed_dict=base_feed)
                        writer.add_summary(summary_str, step)

                        fake_b_from_pool = self.fake_image_pool(
                            self.num_fake_inputs, fake_b_temp,
                            self.fake_b_pool)

                        # Optimizing the D_B network
                        dis_b_feed = dict(base_feed)
                        dis_b_feed[self.fake_pool_b] = fake_b_from_pool
                        _, db_loss, summary_str = sess.run(
                            [
                                self.dis_b_trainer, self.dis_b_loss,
                                self.dis_b_loss_summ
                            ],
                            feed_dict=dis_b_feed)
                        writer.add_summary(summary_str, step)

                        # Optimizing the G_B network
                        _, fake_a_temp, gb_loss, summary_str = sess.run(
                            [
                                self.gen_b_trainer, self.fake_images_a,
                                self.gen_b_loss, self.gen_b_loss_summ
                            ],
                            feed_dict=base_feed)
                        writer.add_summary(summary_str, step)

                        fake_a_from_pool = self.fake_image_pool(
                            self.num_fake_inputs, fake_a_temp,
                            self.fake_a_pool)

                        # Optimizing the D_A network
                        dis_a_feed = dict(base_feed)
                        dis_a_feed[self.fake_pool_a] = fake_a_from_pool
                        _, da_loss, summary_str = sess.run(
                            [
                                self.dis_a_trainer, self.dis_a_loss,
                                self.dis_a_loss_summ
                            ],
                            feed_dict=dis_a_feed)
                        writer.add_summary(summary_str, step)

                        writer.flush()
                        self.num_fake_inputs += 1
                        used = time() - start_time
                        # Projected total time for this epoch at current pace.
                        eta = used * max_image_num / (batch + 1)
                        print(
                            'Epoch: %3d, Batch %4d/%d, GA:%.6f, DA:%.6f GB:%.6f, DB:%.6f, Time:%s/%s'
                            % (epoch, batch, max_image_num, ga_loss, da_loss,
                               gb_loss, db_loss, utils.time2str(used),
                               utils.time2str(eta)))

                    sess.run(tf.assign(self.global_step, epoch + 1))

                # The per-epoch save happens before training, so explicitly
                # persist the weights produced by the final epoch.
                if start_epoch < self._max_step:
                    saver.save(sess, self._ckpt_dir,
                               global_step=self._max_step)

                writer.add_graph(sess.graph)
            finally:
                # Always stop the input queue threads and release the writer,
                # even if training raised.
                coord.request_stop()
                coord.join(threads)
                writer.close()