TrainableModel

abstract class TrainableModel : TensorFlowInferenceModel

Base abstract class for all trainable models.

Constructors

TrainableModel
Link copied to clipboard
fun TrainableModel()

Functions

close
Link copied to clipboard
open override fun close()

Closes internal resources: session and kGraph.

compile
Link copied to clipboard
abstract fun compile(optimizer: Optimizer, loss: LossFunction, metrics: List<Metric>)
abstract fun compile(optimizer: Optimizer, loss: LossFunction, metric: Metric)
abstract fun compile(optimizer: Optimizer, loss: LossFunction, metric: Metrics)
abstract fun compile(optimizer: Optimizer, loss: Losses, metric: Metric)
abstract fun compile(optimizer: Optimizer, loss: Losses, metric: Metrics)

Configures the model for training.

copy
Link copied to clipboard
open override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): TensorFlowInferenceModel

Creates a copy.

evaluate
Link copied to clipboard
fun evaluate(dataset: Dataset, metric: Metrics): Double

Evaluates the dataset using the given metric.

abstract fun evaluate(dataset: Dataset, batchSize: Int = 256, callbacks: List<Callback> = listOf()): EvaluationResult
fun evaluate(dataset: Dataset, batchSize: Int = 256, callback: Callback): EvaluationResult

Returns the metrics and loss values for the model in test (evaluation) mode.

fit
Link copied to clipboard
abstract fun fit(dataset: Dataset, epochs: Int = 5, batchSize: Int = 32, callbacks: List<Callback> = listOf()): TrainingHistory
fun fit(dataset: Dataset, epochs: Int = 5, batchSize: Int = 32, callback: Callback): TrainingHistory
abstract fun fit(trainingDataset: Dataset, validationDataset: Dataset, epochs: Int = 5, trainBatchSize: Int = 32, validationBatchSize: Int = 256, callbacks: List<Callback> = listOf()): TrainingHistory
fun fit(trainingDataset: Dataset, validationDataset: Dataset, epochs: Int = 5, trainBatchSize: Int = 32, validationBatchSize: Int = 256, callback: Callback): TrainingHistory

Trains the model for a fixed number of epochs (iterations over a dataset).

fun fit(dataset: OnHeapDataset, validationRate: Double, epochs: Int, trainBatchSize: Int, validationBatchSize: Int, callbacks: List<Callback> = listOf()): TrainingHistory
fun fit(dataset: OnHeapDataset, validationRate: Double, epochs: Int, trainBatchSize: Int, validationBatchSize: Int, callback: Callback): TrainingHistory

Trains the model for a fixed number of epochs (iterations over a dataset).

graphToString
Link copied to clipboard
fun graphToString(): String

Returns a description of the model graph as a string.

input
Link copied to clipboard
fun input(inputOp: Input)

Chain-like setter to set up inputOp.

loadWeights
Link copied to clipboard
open fun loadWeights(modelDirectory: File, loadOptimizerState: Boolean = false)

Loads variable data from .txt files.

output
Link copied to clipboard
fun output(outputOp: Output)

Chain-like setter to set up outputOp.

predict
Link copied to clipboard
fun predict(dataset: Dataset): List<Int>

Predicts labels for all observations in the dataset.

open override fun predict(inputData: FloatArray): Int

Generates output prediction for the input sample.

abstract fun predict(inputData: FloatArray, predictionTensorName: String): Int

Generates output prediction for the input sample using output of the predictionTensorName tensor.

abstract fun predict(dataset: Dataset, batchSize: Int, callbacks: List<Callback> = listOf()): IntArray
fun predict(dataset: Dataset, batchSize: Int, callback: Callback): IntArray

Generates output predictions for the input samples.

predictAndGetActivations
Link copied to clipboard
abstract fun predictAndGetActivations(inputData: FloatArray, predictionTensorName: String = ""): Pair<Int, List<*>>

Predicts and returns not only the prediction but also a list of activation values from intermediate model layers (for visualisation or debugging purposes).

predictSoftly
Link copied to clipboard
open override fun predictSoftly(inputData: FloatArray, predictionTensorName: String): FloatArray

Predicts a vector of probabilities instead of the specific class returned by the predict method.

abstract fun predictSoftly(dataset: Dataset, batchSize: Int, callbacks: List<Callback> = listOf()): Array<FloatArray>
fun predictSoftly(dataset: Dataset, batchSize: Int, callback: Callback): Array<FloatArray>

Generates output predictions for the input samples. Each prediction is a vector of probabilities instead of the specific class returned by the predict method.

reshape
Link copied to clipboard
open override fun reshape(vararg dims: Long)

Chain-like setter to set up input shape.

save
Link copied to clipboard
abstract fun save(modelDirectory: File, savingFormat: SavingFormat = SavingFormat.TF_GRAPH_CUSTOM_VARIABLES, saveOptimizerState: Boolean = false, writingMode: WritingMode = WritingMode.FAIL_IF_EXISTS)

Saves the model as graph and weights.

summary
Link copied to clipboard
abstract fun summary(): ModelSummary

Returns model summary.

toString
Link copied to clipboard
open override fun toString(): String

Properties

inputDimensions
Link copied to clipboard
open override val inputDimensions: LongArray

Input specification for this model.

isBuiltForForwardMode
Link copied to clipboard
var isBuiltForForwardMode: Boolean = false

Is true when model is ready for forward mode.

isModelCompiled
Link copied to clipboard
var isModelCompiled: Boolean = false

Is true when model is compiled.

isModelInitialized
Link copied to clipboard
var isModelInitialized: Boolean = false

Is true when model is initialized.

isOptimizerVariableInitialized
Link copied to clipboard
var isOptimizerVariableInitialized: Boolean = false

Is true when model optimizer variables are initialized.

kGraph
Link copied to clipboard
lateinit var kGraph: KGraph

TensorFlow wrapped computational graph.

loss
Link copied to clipboard
var loss: LossFunction

Loss function.

name
Link copied to clipboard
var name: String? = null

Model name.

numberOfClasses
Link copied to clipboard
var numberOfClasses: Long

Number of classes for classification tasks. -1 is the default value for regression tasks.

shape
Link copied to clipboard
lateinit var shape: LongArray

Data shape for prediction.

stopTraining
Link copied to clipboard
var stopTraining: Boolean = false

Special flag for callbacks.

Inheritors

GraphTrainableModel
Link copied to clipboard

Extensions

logSummary
Link copied to clipboard
fun TrainableModel.logSummary(logger: Logger = ModelSummaryLogger.logger)

Formats and logs the model summary to logger. By default, logs to ModelSummaryLogger.logger.

printSummary
Link copied to clipboard
fun TrainableModel.printSummary(out: PrintStream = System.out)

Formats and prints the model summary to the output stream out. By default, prints to the console (System.out).