// Load the dataset from wine.csv (resolved relative to the working directory).
Table wine = Table.read().csv("wine.csv");

// Part A: Display summary statistics for the table
System.out.println("Summary Statistics:");
System.out.println(wine.summary());

// Part B: Create a histogram of wine quality distribution
Figure qualityHist = Histogram.create("Wine Quality Distribution", wine, "quality");
Plot.show(qualityHist);

// Part B continued: Create a scatter plot of alcohol vs quality
Figure alcoholScatter = ScatterPlot.create("Alcohol vs Quality", wine, "alcohol", "quality");
Plot.show(alcoholScatter);

// Provided: Group wines by quality level.
// Tablesaw's summarize() takes the column names first and the aggregate
// function(s) last; alternating name/function pairs match no overload.
// The single `mean` function is applied to each of the three listed columns.
Table qualityGroups = wine
    .summarize("alcohol", "pH", "volatile acidity", AggregateFunctions.mean)
    .by("quality");
System.out.println("\nCharacteristics by quality level:");
System.out.println(qualityGroups);
// Load the dataset as a Tablesaw table.
Table wine = Table.read().csv("wine.csv");

// Convert Tablesaw table to SMILE DataFrame.
// NOTE(review): doubleMatrix() presumably requires every column to be numeric — confirm
// against the CSV before adding non-numeric columns.
String[] colNames = wine.columnNames().toArray(String[]::new);
double[][] data = wine.as().doubleMatrix();
DataFrame df = DataFrame.of(data, colNames);

// Convert quality to IntVector (classification target): the double matrix stores the
// label as a double, so truncate it back to int and swap the column in.
IntVector quality = IntVector.of("quality", df.doubleVector("quality").stream()
    .mapToInt(d -> (int) d)
    .toArray());
df = df.drop("quality").merge(quality);

// Split data into training and test sets (80/20 split).
// The row indices are shuffled and then used to select the rows, so the split is
// random rather than following the order of the CSV file.
int n = df.nrows();
int[] indices = IntStream.range(0, n).toArray();
MathEx.permutate(indices);
int splitIndex = (int) (n * 0.8);

DataFrame trainDf = df.of(java.util.Arrays.copyOfRange(indices, 0, splitIndex));
DataFrame testDf = df.of(java.util.Arrays.copyOfRange(indices, splitIndex, n));

// Part A: Train a Random Forest model using SMILE.
// SMILE's RandomForest is fitted from a formula and a DataFrame; "quality" is the
// response and every other column is used as a predictor.
RandomForest rf = RandomForest.fit(smile.data.formula.Formula.lhs("quality"), trainDf);

// Part B: Calculate and display model accuracy on the held-out rows
int[] yTrue = testDf.intVector("quality").toIntArray();
int[] yPred = rf.predict(testDf);
double accuracy = Accuracy.of(yTrue, yPred);

System.out.printf("SMILE Random Forest Accuracy: %.2f%%\n", accuracy * 100);
// Load the dataset as a Tablesaw table.
Table wine = Table.read().csv("wine.csv");

// Convert to Weka format.
// Every column except "quality" becomes a numeric attribute. The feature names are
// recorded in the same order as the attributes so that values can later be copied by
// name, which keeps them aligned even if "quality" is not the last column of the CSV.
ArrayList<String> featureNames = new ArrayList<>();
ArrayList<Attribute> attributes = new ArrayList<>();
for (String col : wine.columnNames()) {
    if (!col.equals("quality")) {
        featureNames.add(col);
        attributes.add(new Attribute(col));
    }
}

// "quality" becomes a nominal class attribute with one label per integer in
// [minQuality, maxQuality].
IntColumn qualityCol = wine.intColumn("quality");
int minQuality = (int) qualityCol.min();
int maxQuality = (int) qualityCol.max();
ArrayList<String> qualityVals = new ArrayList<>();
for (int i = minQuality; i <= maxQuality; i++) {
    qualityVals.add(String.valueOf(i));
}
attributes.add(new Attribute("quality", qualityVals));

Instances wData = new Instances("Wine", attributes, wine.rowCount());
int classIndex = wData.numAttributes() - 1;
wData.setClassIndex(classIndex);

for (int i = 0; i < wine.rowCount(); i++) {
    double[] vals = new double[wData.numAttributes()];
    for (int j = 0; j < featureNames.size(); j++) {
        vals[j] = ((NumberColumn<?, ?>) wine.column(featureNames.get(j))).getDouble(i);
    }
    // A nominal value is stored as the index of its label in the attribute's value list.
    vals[classIndex] = qualityVals.indexOf(String.valueOf(qualityCol.getInt(i)));
    wData.add(new DenseInstance(1.0, vals));
}

// Split data (80/20).
// Shuffle first so the split does not depend on the row order of the CSV; the fixed
// seed keeps the reported accuracy reproducible between runs.
wData.randomize(new java.util.Random(42));
int trainSize = (int) Math.round(wData.numInstances() * 0.8);
Instances train = new Instances(wData, 0, trainSize);
Instances test = new Instances(wData, trainSize, wData.numInstances() - trainSize);

// Train Weka Random Forest and calculate accuracy
RandomForest wekaRf = new RandomForest();

try {
    wekaRf.buildClassifier(train);

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(wekaRf, test);

    // pctCorrect() is already expressed as a percentage, so it is printed as-is.
    System.out.printf("Weka Random Forest Accuracy: %.2f%%\n", eval.pctCorrect());

    // Compare models
    System.out.println("\nModel Comparison Complete!");
    System.out.println("Which model performed better? Analyze the results above.");

} catch (Exception e) {
    // Weka's training and evaluation calls declare a bare Exception.
    e.printStackTrace();
}