import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from pureml.decorators import dataset
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from pureml.decorators import model

@dataset("petdata")
def load_data(img_folder="PetImages"):
  """Build the training and validation image datasets from a directory.

  The directory is read with ``tf.keras.utils.image_dataset_from_directory``
  using an 80/20 train/validation split (fixed seed so the split is
  reproducible), images resized to 180x180 and batches of 16. With the
  Keras defaults, class labels are inferred from the sub-directory layout.

  Random horizontal flips and small rotations are applied to the training
  split only; both splits are prefetched.

  Args:
    img_folder: Path to the root image directory.

  Returns:
    A ``(train_ds, val_ds)`` tuple of ``tf.data.Dataset`` objects.
  """
  image_size = (180, 180)
  batch_size = 16

  # subset="both" makes Keras return the (training, validation) pair in one
  # call, so the result must be unpacked in a single assignment.
  train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
      img_folder,
      validation_split=0.2,
      subset="both",
      seed=1337,
      image_size=image_size,
      batch_size=batch_size,
  )

  data_augmentation = keras.Sequential(
      [
          layers.RandomFlip("horizontal"),
          layers.RandomRotation(0.1),
      ]
  )

  # Augment only the training split. training=True is passed explicitly
  # because the pipeline runs outside of fit(), and the random layers are
  # no-ops in inference mode.
  train_ds = train_ds.map(
      lambda img, label: (data_augmentation(img, training=True), label),
      num_parallel_calls=tf.data.AUTOTUNE,
  )

  train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
  val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
  return train_ds, val_ds

@model("pet_classifier")
def train_model(train_ds, val_ds):
  """Train a binary image classifier on top of an InceptionV3 backbone.

  An ImageNet-initialised InceptionV3 (without its classification head) is
  followed by global average pooling, a 1024-unit ReLU layer and a single
  sigmoid output unit. The whole network is compiled with RMSprop and
  binary cross-entropy and fitted for 2 epochs.

  Args:
    train_ds: Training dataset yielding ``(images, labels)`` batches;
      images are expected to be 180x180 RGB to match the model input.
    val_ds: Validation dataset in the same format.

  Returns:
    The fitted ``keras.Model``.
  """
  input_tensor = Input(shape=(180, 180, 3))
  # NOTE(review): the ImageNet weights were trained on inputs scaled with
  # inception_v3.preprocess_input; no such scaling is applied here or in the
  # visible data pipeline -- confirm whether that is intentional.
  base_model = InceptionV3(
      input_tensor=input_tensor,
      weights='imagenet',
      include_top=False,
  )

  x = base_model.output
  x = GlobalAveragePooling2D()(x)
  x = Dense(1024, activation='relu')(x)
  # A single-unit softmax always outputs 1.0, so the loss would carry no
  # signal; sigmoid is the correct pairing with binary_crossentropy.
  predictions = Dense(1, activation='sigmoid')(x)

  model_inc = Model(
      inputs=base_model.input,
      outputs=predictions,
  )
  model_inc.compile(
      optimizer='rmsprop',
      loss='binary_crossentropy',
      metrics=["accuracy"],
  )
  model_inc.fit(
      train_ds,
      epochs=2,
      validation_data=val_ds,
  )
  return model_inc