예제 #1
0
 def test_no_u(self, tmpdir):
     """Without any ``--u_connect`` flag, ``args.u_connect`` is empty.

     ``sys.argv`` is patched only for the duration of ``parse_args`` so the
     simulated command line is restored afterwards instead of leaking into
     later tests.
     """
     from unittest import mock

     argv = [
         'train',
         '--output_folder', str(tmpdir),
     ]
     with mock.patch.object(sys, 'argv', argv):
         args = parse_args()
     assert len(args.u_connect) == 0
예제 #2
0
 def test_one_to_many(self, tmpdir):
     """One source tensor map may be U-connected to several targets.

     Two ``--u_connect`` flags sharing the source key ``3d_cont`` must end up
     as a single ``u_connect`` entry whose value is the set of both targets.
     """
     from unittest import mock

     key1 = '3d_cont'
     key2 = '3d_cat'
     argv = [
         'train',
         '--output_folder', str(tmpdir),
         '--input_tensors', key1, key2,
         '--output_tensors', key1, key2,
         '--u_connect', key1, key1,
         '--u_connect', key1, key2,
     ]
     # Patch instead of assigning so the fake command line is undone after parsing.
     with mock.patch.object(sys, 'argv', argv):
         args = parse_args()
     assert len(args.u_connect) == 1
     assert args.u_connect[MOCK_TMAPS[key1]] == {MOCK_TMAPS[key1], MOCK_TMAPS[key2]}
예제 #3
0
 def test_many_to_one(self, tmpdir):
     """Several source tensor maps may be U-connected to the same target.

     Each distinct source key gets its own ``u_connect`` entry, and both
     entries map to the single target ``{MOCK_TMAPS['3d_cont']}``.
     """
     from unittest import mock

     inp_key1 = '3d_cont'
     inp_key2 = '3d_cat'
     argv = [
         'train',
         '--output_folder', str(tmpdir),
         '--input_tensors', inp_key1, inp_key2,
         '--output_tensors', inp_key1,
         '--u_connect', inp_key1, inp_key1,
         '--u_connect', inp_key2, inp_key1,
     ]
     # Patch instead of assigning so the fake command line is undone after parsing.
     with mock.patch.object(sys, 'argv', argv):
         args = parse_args()
     assert len(args.u_connect) == 2
     assert args.u_connect[MOCK_TMAPS[inp_key1]] == {MOCK_TMAPS[inp_key1]}
     assert args.u_connect[MOCK_TMAPS[inp_key2]] == {MOCK_TMAPS[inp_key1]}
예제 #4
0
 def test_simple_u(self, tmpdir):
     """A single ``--u_connect a a`` flag yields exactly ``{tmap: {tmap}}``."""
     from unittest import mock

     inp_key = '3d_cont'
     argv = [
         'train',
         '--output_folder', str(tmpdir),
         '--input_tensors', inp_key,
         '--output_tensors', inp_key,
         '--u_connect', inp_key, inp_key,
     ]
     # Patch instead of assigning so the fake command line is undone after parsing.
     with mock.patch.object(sys, 'argv', argv):
         args = parse_args()
     assert len(args.u_connect) == 1
     # Unpack the lone (source, targets) pair without materialising a list.
     [(inp, out)] = args.u_connect.items()
     tmap = MOCK_TMAPS[inp_key]
     assert inp == tmap
     assert out == {tmap}
예제 #5
0
def default_arguments(tmpdir_factory):
    """Create a small HDF5 dataset and return arguments parsed against it.

    A fresh ``data`` directory is made via ``tmpdir_factory``, and
    ``build_hdf5s`` fills it with ``pytest.N_TENSORS`` samples for every tensor
    map in ``MOCK_TMAPS``. A minimal training command line pointing at that
    directory is then written to ``sys.argv`` and handed to ``parse_args``.
    NOTE(review): presumably registered as a pytest fixture; the decorator is
    not visible here.
    """
    data_dir = tmpdir_factory.mktemp('data')
    build_hdf5s(data_dir, MOCK_TMAPS.values(), n=pytest.N_TENSORS)
    data_path = str(data_dir)
    input_key = '3d_cont'
    output_key = '1d_cat'
    # One flag/value pair per line; argv[0] (program name) is left blank.
    sys.argv = [
        '',
        '--output_folder', data_path,
        '--input_tensors', input_key,
        '--output_tensors', output_key,
        '--tensors', data_path,
        '--pool_x', '1',
        '--pool_y', '1',
        '--pool_z', '1',
        '--training_steps', '2',
        '--test_steps', '3',
        '--validation_steps', '2',
        '--epochs', '2',
        '--num_workers', '0',
        '--batch_size', '2',
    ]
    return parse_args()
예제 #6
0
        ax2.plot(val_loss,
                 label=label,
                 linestyle=linestyles[i % 4],
                 color=color)
        ax2.text(len(val_loss) - 1, val_loss[-1], str(i))
    ax1.axhline(cutoff,
                label=f'Loss display cutoff at {cutoff:.3f}',
                color='k',
                linestyle='--')
    ax1.set_title('Training Loss')
    ax2.axhline(cutoff,
                label=f'Loss display cutoff at {cutoff:.3f}',
                color='k',
                linestyle='--')
    ax2.set_title('Validation Loss')
    ax3.legend(*ax2.get_legend_handles_labels(),
               loc='upper center',
               fontsize='x-small',
               mode='expand',
               ncol=5)
    ax3.axis('off')
    learning_path = os.path.join(figure_path, 'learning_curves' + IMAGE_EXT)
    plt.tight_layout()
    plt.savefig(learning_path)
    logging.info('Saved learning curve plot to: {}'.format(learning_path))


if __name__ == '__main__':
    # Script entry point: parse the command-line options and pass them to run().
    args = parse_args()
    run(args)  # back to the top (NOTE(review): presumably run() is defined near the top of this module)