import pandas as pd
import unittest
from unittest.mock import patch, MagicMock, Mock
from mage_ai.data_preparation.models.pipeline import Pipeline
class ExampleIOTest(unittest.TestCase):
    """Tests for the ``load_data_to_postgres`` pipeline with mocked IO clients.

    ``BigQuery.with_config`` and ``Postgres.with_config`` are patched so that
    no real connection is made. The tests then verify that the pipeline (or
    its individual blocks) issues the expected ``load`` query and calls
    ``export`` with the expected arguments.
    """

    # Expectations shared by both tests.
    EXPECTED_QUERY = 'SELECT * FROM xyz'
    EXPECTED_SCHEMA = 'public'
    EXPECTED_TABLE = 'test_table_name'
    EXPECTED_SHAPE = (3, 3)

    def setUp(self):
        """Fetch the pipeline under test before every test method."""
        self.pipeline = Pipeline.get('load_data_to_postgres')

    def tearDown(self):
        """
        Optionally clean up the output of the pipeline/block execution.
        Make sure to only do this if you are ok losing all variable data for
        the project.
        """

    @staticmethod
    def _make_io_mock():
        """Return a ``MagicMock`` that yields itself when used in ``with``.

        By default ``MagicMock().__enter__()`` returns a *different* mock, so
        calls made on the ``as`` target of a ``with`` statement would not be
        recorded on the object the test holds. Wiring ``__enter__`` back to
        the mock itself makes assertions work whether or not the code under
        test uses the client as a context manager.
        """
        io_mock = MagicMock()
        io_mock.__enter__ = Mock(return_value=io_mock)
        return io_mock

    def _make_bigquery_mock(self):
        """Return a BigQuery client mock whose ``load`` yields a 3x3 frame."""
        bq_mock = self._make_io_mock()
        bq_mock.load.return_value = pd.DataFrame(
            data=dict(
                a=[1, 2, 3],
                b=[4, 5, 6],
                c=[7, 8, 9],
            )
        )
        return bq_mock

    def _assert_bigquery_loaded(self, bq_mock):
        """Assert ``load`` was called exactly once with the expected query."""
        bq_mock.load.assert_called_once_with(self.EXPECTED_QUERY)

    def _assert_postgres_exported(self, postgres_mock):
        """Assert ``export`` was called once with the expected arguments.

        The first three positional arguments are checked as the data frame,
        the schema name and the table name; ``index`` and ``if_exists`` are
        checked as keyword arguments.
        """
        postgres_mock.export.assert_called_once()
        args, kwargs = postgres_mock.export.call_args
        # Fail with a readable message instead of an IndexError below.
        self.assertGreaterEqual(
            len(args),
            3,
            'export() should receive df, schema name and table name positionally',
        )
        df, schema_name, table_name = args[:3]
        self.assertEqual(df.shape, self.EXPECTED_SHAPE)
        self.assertEqual(schema_name, self.EXPECTED_SCHEMA)
        self.assertEqual(table_name, self.EXPECTED_TABLE)
        self.assertEqual(kwargs['index'], False)
        self.assertEqual(kwargs['if_exists'], 'replace')

    def test_pipeline_execution(self):
        """Run the whole pipeline and check both the load and export calls."""
        bq_mock = self._make_bigquery_mock()
        postgres_mock = self._make_io_mock()
        with patch(
            'mage_ai.io.postgres.Postgres.with_config', return_value=postgres_mock
        ), patch('mage_ai.io.bigquery.BigQuery.with_config', return_value=bq_mock):
            self.pipeline.execute_sync(global_vars=dict(num_rows=10))
        # The mocks keep their recorded calls after the patches are undone.
        self._assert_bigquery_loaded(bq_mock)
        self._assert_postgres_exported(postgres_mock)

    def test_block_execution(self):
        """Run the two blocks one at a time and check each block's IO call.

        ``primal_dawn`` is executed with BigQuery patched and ``little_glitter``
        with Postgres patched; the first block runs before the second, as in
        the original test.
        """
        block1 = self.pipeline.get_block('primal_dawn')
        block2 = self.pipeline.get_block('little_glitter')

        bq_mock = self._make_bigquery_mock()
        with patch('mage_ai.io.bigquery.BigQuery.with_config', return_value=bq_mock):
            block1.execute_sync()
        self._assert_bigquery_loaded(bq_mock)

        postgres_mock = self._make_io_mock()
        with patch(
            'mage_ai.io.postgres.Postgres.with_config', return_value=postgres_mock
        ):
            block2.execute_sync()
        self._assert_postgres_exported(postgres_mock)