Example #1
0
    def on_end_epoch(hook_state, state):
        """End-of-epoch hook: run validation, log meter values, and save a
        checkpoint of the best model with early stopping.

        Args:
            hook_state: dict persisted across epochs; holds the best
                validation loss seen so far and the early-stopping wait
                counter (both lazily initialized here).
            state: training-engine state dict; reads 'model' and 'epoch'
                and sets 'stop' to True to request early stopping.
        """
        if val_loader is not None:
            # Lazily initialize early-stopping bookkeeping on first call.
            hook_state.setdefault('best_loss', np.inf)
            hook_state.setdefault('wait', 0)

            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(
                                     state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(
            state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        # Append one JSON record per epoch to the trace file.
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print(
                    "==> best model (loss = {:0.6f}), saving model...".format(
                        hook_state['best_loss']))

                # Move to CPU so the checkpoint is device-independent.
                state['model'].cpu()
                torch.save(
                    state['model'].state_dict(),
                    os.path.join(opt['log.exp_dir'],
                                 'best_model_state_dict.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(
                        opt['train.patience']))
                    state['stop'] = True
        else:
            # No validation set: unconditionally overwrite the checkpoint.
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()
Example #2
0
    def on_end_epoch(hook_state, state):
        """End-of-epoch hook: validate through the shared state['loader'],
        log meter values, and save the best model with early stopping.

        Args:
            hook_state: dict persisted across epochs; holds the best
                validation loss and early-stopping wait counter.
            state: training-engine state dict; reads 'model', 'loader' and
                'epoch' and sets 'stop' to True to stop training.
        """
        if val_loader is not None:
            # Lazily initialize early-stopping bookkeeping on first call.
            hook_state.setdefault('best_loss', np.inf)
            hook_state.setdefault('wait', 0)

            # Switch the shared loader into validation mode before eval.
            state['loader'].mode = 'val'
            model_utils.evaluate(
                state['model'],
                state['loader'],
                meters['val'],
                desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(
            state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        # Coerce nested meter values to plain floats so json.dump does not
        # fail on CUDA tensors ("this prevents cuda bugs").
        meter_vals = {
            k: ({kk: float(vv)
                 for kk, vv in v.items()} if isinstance(v, dict) else v)
            for k, v in meter_vals.items()
        }
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print(
                    "==> best model (loss = {:0.6f}), saving model...".format(
                        hook_state['best_loss']))

                # Move to CPU so the checkpoint is device-independent.
                state['model'].cpu()
                torch.save(state['model'],
                           os.path.join(opt['log.exp_dir'], 'best_model.t7'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(
                        opt['train.patience']))
                    state['stop'] = True
        else:
            # No validation set: unconditionally overwrite the checkpoint.
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.t7'))
            if opt['data.cuda']:
                state['model'].cuda()
Example #3
0
 def on_update(state):
     """Per-batch hook: accumulate training meters; log every 5th batch."""
     for key, meter in meters['train'].items():
         meter.add(state['output'][key])
     # Only report running meter values on every fifth batch.
     if state['batch'] % 5 != 0:
         return
     rendered = log_utils.render_meter_values(
         log_utils.extract_meter_values(meters))
     print("Epoch {:02d}: {:s}".format(state['epoch'], rendered))
Example #4
0
    def on_end_epoch(hook_state, state):
        """End-of-epoch hook: validate, log, write periodic checkpoints,
        and track the best model with early stopping.

        Args:
            hook_state: dict persisted across epochs; holds the best
                validation loss and early-stopping wait counter.
            state: training-engine state dict; reads 'model', 'optimizer'
                and 'epoch' and sets 'stop' to True to stop training.
        """
        if val_loader is not None:
            # Lazily initialize early-stopping bookkeeping on first call.
            hook_state.setdefault('best_loss', np.inf)
            hook_state.setdefault('wait', 0)

            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        # Append one JSON record per epoch to the trace file.
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        # Periodic checkpoint: first epoch, then every 10th epoch.
        if state['epoch'] == 1 or state['epoch'] % 10 == 0:
            # Move to CPU so the checkpoint is device-independent.
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], f'epoch_{state["epoch"]}.pt'))
            torch.save(state['optimizer'].state_dict(),
                       os.path.join(opt['log.exp_dir'], f'epoch_{state["epoch"]}_optim.pt'))
            if opt['data.cuda']:
                state['model'].cuda()

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                torch.save(state['optimizer'].state_dict(), os.path.join(opt['log.exp_dir'], 'best_model_optim.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            # No validation set: overwrite the checkpoint every epoch.
            # Save the optimizer too, mirroring the validation branch so a
            # run can be resumed either way.
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            torch.save(state['optimizer'].state_dict(), os.path.join(opt['log.exp_dir'], 'best_model_optim.pt'))
            if opt['data.cuda']:
                state['model'].cuda()
Example #5
0
File: train.py  Project: zhuzhenxi/aitom
    def on_end_epoch(hook_state, state):
        """End-of-epoch hook: validate without gradients, log meter values,
        and keep the model with the highest validation accuracy.

        NOTE(review): hook_state['best_loss'] actually stores the best
        *accuracy* (initialized to 0, updated when acc increases). The key
        name is kept as-is so any persisted hook state stays compatible.
        """
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = 0
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

            # Gradients are not needed during evaluation.
            with torch.no_grad():
                model_utils.evaluate(state['model'],
                                     val_loader,
                                     meters['val'],
                                     opt['model.stage'],
                                     desc="Epoch {:d} valid".format(
                                         state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(
            state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        # One JSON record per epoch, appended to the trace file.
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is None:
            # No validation set: unconditionally overwrite the checkpoint.
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()
            return

        current_acc = meter_vals['val']['acc']
        if current_acc > hook_state['best_loss']:
            hook_state['best_loss'] = current_acc
            print("==> best model (acc = {:0.6f}), saving model...".format(
                hook_state['best_loss']))

            # Move to CPU so the checkpoint is device-independent.
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()

            hook_state['wait'] = 0
        else:
            hook_state['wait'] += 1
            if hook_state['wait'] > opt['train.patience']:
                print("==> patience {:d} exceeded".format(
                    opt['train.patience']))
                state['stop'] = True
    def on_end_epoch(hook_state, state):
        """End-of-epoch hook: validate, write trace-file and TensorBoard
        logs, and checkpoint the best model with early stopping.

        Bug fix: the TensorBoard calls previously read meter_vals['val'][...]
        and hook_state['best_loss'] unconditionally, which raises KeyError
        when val_loader is None (a case the final else branch explicitly
        supports); those calls are now guarded.

        Args:
            hook_state: dict persisted across epochs; holds the best
                validation loss and early-stopping wait counter.
            state: training-engine state dict; reads 'model' and 'epoch'
                and sets 'stop' to True to stop training.
        """
        if val_loader is not None:
            # Lazily initialize early-stopping bookkeeping on first call.
            hook_state.setdefault('best_loss', np.inf)
            hook_state.setdefault('wait', 0)

            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(
                                     state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(
            state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        # Append one JSON record per epoch to the trace file.
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        epoch = state['epoch']
        # Validation metrics only exist when a validation loader is present.
        if val_loader is not None:
            tensorboard.add_scalar("Value MAP", meter_vals['val']['map'],
                                   epoch)
            tensorboard.add_scalar("Value Loss", meter_vals['val']['loss'],
                                   epoch)
            tensorboard.add_scalar("Value Accuracy", meter_vals['val']['acc'],
                                   epoch)
            tensorboard.add_scalar("Best Loss", hook_state['best_loss'],
                                   epoch)
        tensorboard.add_scalar("Train MAP", meter_vals['train']['map'],
                               epoch)
        tensorboard.add_scalar("Train Loss", meter_vals['train']['loss'],
                               epoch)
        tensorboard.add_scalar("Train Accuracy", meter_vals['train']['acc'],
                               epoch)

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print(
                    "==> best model (loss = {:0.6f}), saving model...".format(
                        hook_state['best_loss']))

                # Move to CPU so the checkpoint is device-independent.
                state['model'].cpu()
                torch.save(state['model'],
                           os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(
                        opt['train.patience']))
                    state['stop'] = True
        else:
            # No validation set: unconditionally overwrite the checkpoint.
            state['model'].cpu()
            torch.save(state['model'],
                       os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()