Exemplo n.º 1
0
def _write_savepoint(tag, i_start):
    """Persist options, benchmark history and network weights under ``tag``.

    Sets ``benchmark['i_start']`` to ``i_start``, then writes
    ``result_file_header + tag + '.json'`` (options + benchmark) followed by
    ``result_file_header + tag + '.pt'`` (the same dict plus
    ``net.state_dict()``).

    Args:
        tag: File-name suffix, e.g. ``str(i)`` or ``'latest'``.
        i_start: Iteration counter recorded as ``benchmark['i_start']``.
    """
    benchmark['i_start'] = i_start
    savepoint = {'o': vars(o), 'benchmark': benchmark}
    # The JSON copy is written first, without the network weights; the
    # weights are only added for the torch checkpoint.
    utils.save_json(savepoint, result_file_header + tag + '.json')
    savepoint['net_states'] = net.state_dict()
    # savepoint['optim_states'] = optimizer.state_dict()
    torch.save(savepoint, result_file_header + tag + '.pt')


def run_train_epoch(batch_id_start):
    """Run training batches from ``batch_id_start`` to the end of the epoch.

    Relies on module-level state: ``net``, the options object ``o``,
    ``benchmark``, ``iter_num``, ``result_file_header`` and the helpers
    ``run_batch`` / ``run_test_epoch``. The global iteration counter ``i``
    and the running loss ``train_loss_sum`` are rebound on every batch.

    After each batch, depending on the configured ``o.*_interval`` values,
    this may reset the network states, print progress, log the average
    training loss, run validation and write an ``<i>`` checkpoint, and write
    the ``latest`` checkpoint.

    Args:
        batch_id_start: Index of the first training batch to run
            (presumably lets a resumed run continue mid-epoch; confirm
            against the caller).
    """
    global i, train_loss_sum
    net.reset_states()
    net.train()
    for batch_id in range(batch_id_start, o.train_batch_num):
        o.batch_id = batch_id
        loss = run_batch(batch_id, 'train')
        train_loss_sum = train_loss_sum + loss
        i = i + 1

        # reset_interval in (0, 1): reset with that probability each iteration.
        # reset_interval > 1: reset every round(max(reset_interval / o.T, 1))
        # iterations. Values <= 0 -- and exactly 1 -- never trigger a reset.
        interval = o.reset_interval
        random_reset = 0 < interval < 1 and np.random.random_sample() < interval
        periodic_reset = (interval > 1
                          and i % round(max(interval / o.T, 1)) == 0)
        if random_reset or periodic_reset:
            print('---------- State Reset ----------')
            net.reset_states()

        if o.print_interval > 0 and i % o.print_interval == 0:
            print('Epoch: %.2f/%d, iter: %d/%d, batch: %d/%d, loss: %.3f' %
                  (i / o.train_batch_num, o.epoch_num, i, iter_num,
                   batch_id + 1, o.train_batch_num, loss))

        if o.train_log_interval > 0 and i % o.train_log_interval == 0:
            # Average of the losses accumulated since the last log point.
            benchmark['train_loss'].append(
                (i, train_loss_sum / o.train_log_interval))
            train_loss_sum = 0

        if o.validate_interval > 0 and (i % o.validate_interval == 0
                                        or i == iter_num):
            val_loss = run_test_epoch() if o.test_batch_num > 0 else 0
            # Validation switches the network to eval mode; switch back.
            net.train()
            benchmark['val_loss'].append((i, val_loss))
            _write_savepoint(str(i), i)

        if o.save_interval > 0 and (i % o.save_interval == 0 or i == iter_num):
            _write_savepoint('latest', i)

        print('-' * 80)
Exemplo n.º 2
0
def run_test_epoch():
    """Evaluate the network on all test batches and return the mean loss.

    The states in ``net.states`` are snapshotted to
    ``result_file_header + 'tmp.pt'`` beforehand and restored afterwards, so
    validation does not disturb the training-time states. The restore runs
    in a ``finally`` block, so it also happens when a batch raises (or when
    ``o.test_batch_num`` is 0 and computing the mean divides by zero).
    The network is left in eval mode; the caller must switch it back.

    When ``o.v > 0``, ``run_batch`` is expected to return ``(loss, video)``;
    the collected videos are written to
    ``<o.pic_dir>/tba_mot_annotations_masks.json``.

    Returns:
        The mean of the per-batch validation losses.
    """
    tmp_states_path = result_file_header + 'tmp.pt'
    torch.save(net.states, tmp_states_path)
    net.reset_states()
    net.eval()
    val_loss_sum = 0
    videos = []
    try:
        for batch_id in range(o.test_batch_num):
            o.batch_id = batch_id
            if o.v > 0:
                loss, video = run_batch(batch_id, 'test')
                videos.append(video)
            else:
                loss = run_batch(batch_id, 'test')
            val_loss_sum = val_loss_sum + loss
            print('Validation %d / %d, loss = %.3f' %
                  (batch_id + 1, o.test_batch_num, loss))
        val_loss = val_loss_sum / o.test_batch_num
        print('Final validation loss: %.3f' % (val_loss))
    finally:
        # Restore the training-time states even if validation failed.
        net.states = torch.load(tmp_states_path)
    if o.v > 0:
        utils.save_json(videos, path.join(o.pic_dir, 'tba_mot_annotations_masks.json'))
    return val_loss
Exemplo n.º 3
0
 def save_vocabs_and_config(self,
                            idx2label_path=None,
                            idx2cls_path=None,
                            config_path=None):
     """Save the label vocabulary, the optional class vocabulary and the config.

     Each ``*_path`` argument is passed through ``if_none`` together with the
     matching ``self.*_path`` attribute -- presumably so the attribute is used
     when the argument is ``None`` (confirm against ``if_none``).

     Writes, via ``save_json`` and in this order:
       * ``self.idx2label`` to the resolved label path;
       * ``self.idx2cls`` to the resolved class path, only when it is truthy;
       * ``self.get_config()`` to the resolved config path.
     """
     label_target, cls_target, config_target = (
         if_none(idx2label_path, self.idx2label_path),
         if_none(idx2cls_path, self.idx2cls_path),
         if_none(config_path, self.config_path),
     )
     logging.info("Saving vocabs...")
     save_json(self.idx2label, label_target)
     # An empty or missing class vocabulary is skipped rather than written.
     if self.idx2cls:
         save_json(self.idx2cls, cls_target)
     save_json(self.get_config(), config_target)
Exemplo n.º 4
0
    # Write a parameters file with default training settings when none exists
    # (or when overwrite_parameters is set); otherwise load the existing one.
    if not isfile(parameters_path) or overwrite_parameters:
        parameters = {
            "training": {
                "only_static_letters": False,
                "data_augmentation": True,
                "frames_per_video": 16,
                "epochs": 10,
                "lr_features_extractor": 1e-05,
                "lr_classification": 0.001,
                "batch_size": 4,
                "num_workers": 2,
                "pretrained_resnet": True,
                "use_optical_flow": False
            }
        }
        save_json(content=parameters, filepath=parameters_path)
        print(f"Created a parameters file in {parameters_path}")
    else:
        parameters = read_json(filepath=parameters_path)
        print(f"Read parameters from {parameters_path}")
    pprint(parameters)

    print("Loading datasets...")
    frames_per_video = parameters["training"]["frames_per_video"]
    # NOTE(review): only_static_letters is taken from the CLI `args`, while
    # parameters["training"]["only_static_letters"] is never consulted here;
    # confirm which source is meant to win.
    ds_train = load_dataset(samples_path=train_data_path,
                            frames_per_video=frames_per_video,
                            only_static_letters=args.only_static_letters)
    ds_val = load_dataset(samples_path=val_data_path,
                          frames_per_video=frames_per_video,
                          only_static_letters=args.only_static_letters)
    print(f"\t|train| = {len(ds_train)}\t|val| = {len(ds_val)}")