Merge branch 'master' of https://scm.cri.ensmp.fr/git/Faustine
[Faustine.git] / interpreter / faustio.ml
index 7df74a6..9539bf3 100644 (file)
@@ -12,16 +12,20 @@ open Signal;;
 open Beam;;
 open Aux;;
 
 open Beam;;
 open Aux;;
 
+exception Faustine_IO_Error of string;;
+
 let csv_read_buffer_length = 0xFFFF;;
 
 class virtual io = 
   object
 let csv_read_buffer_length = 0xFFFF;;
 
 class virtual io = 
   object
+    val mutable _filename = ""
     val mutable _basename = ""
     val mutable _dir = ""
     val mutable _basename = ""
     val mutable _dir = ""
-    method set : string -> string -> unit = 
-      fun (dir : string) ->
-       fun (basename : string) ->
-         _basename <- basename; _dir <- dir
+    method set : string -> string -> string -> unit = 
+      fun (filename : string) ->
+       fun (dir : string) ->
+         fun (basename : string) ->
+           _filename <- filename; _basename <- basename; _dir <- dir
 
     method virtual read : string array -> beam
     method virtual write : rate array -> data -> string array
 
     method virtual read : string array -> beam
     method virtual write : rate array -> data -> string array
@@ -70,11 +74,18 @@ class waveio : io_type =
       fun (rates : rate array) ->
        fun (output : data) ->
          let n = Array.length output in          
       fun (rates : rate array) ->
        fun (output : data) ->
          let n = Array.length output in          
-         let paths = Array.init n (fun i -> 
-           _dir ^ _basename ^ (string_of_int (i + 1)) ^ ".wav") in
+         let paths = 
+           if _filename = "" then 
+             Array.init n (fun i -> 
+               _dir ^ _basename ^ (string_of_int (i + 1)) ^ ".wav") 
+           else if n = 1 then 
+             let () = Unix.unlink _filename in [|_filename|]
+           else raise (Faustine_IO_Error ("The process has several output signals, 
+                       however stdout supports only one output signal. Please remove 
+                       the '> " ^ _filename ^ "'.")) in
          let get_freq = fun (r : rate) -> r#to_int in
          let freqs = Array.map get_freq rates in
          let get_freq = fun (r : rate) -> r#to_int in
          let freqs = Array.map get_freq rates in
-
+         
          let files = 
            let channels = self#channels output in 
            let file_format = Sndfile.format 
          let files = 
            let channels = self#channels output in 
            let file_format = Sndfile.format 
@@ -124,12 +135,16 @@ class csvio : io_type =
     method write : rate array -> data -> string array = 
       fun (rates : rate array) ->
        fun (data : data) ->
     method write : rate array -> data -> string array = 
       fun (rates : rate array) ->
        fun (data : data) ->
+         let n = Array.length data in
          let paths = 
          let paths = 
-           let n = Array.length data in
-           let path_pattern = fun i -> 
-             _dir ^ _basename ^ (string_of_int (i + 1)) ^ ".csv" in
-           Array.init n path_pattern in          
-
+           if _filename = "" then 
+             Array.init n (fun i -> 
+               _dir ^ _basename ^ (string_of_int (i + 1)) ^ ".csv") 
+           else if n = 1 then 
+             let () = Unix.unlink _filename in [|_filename|]
+           else raise (Faustine_IO_Error ("The process has several output signals, 
+                       however stdout supports only one output signal. Please remove 
+                       the '> " ^ _filename ^ "'.")) in
          let files = Array.map open_out paths in
          let strings = 
            let value2string : float array -> string =
          let files = Array.map open_out paths in
          let strings = 
            let value2string : float array -> string =
@@ -151,19 +166,14 @@ class iomanager =
   object (self)
     val wave = new waveio
     val csv = new csvio
   object (self)
     val wave = new waveio
     val csv = new csvio
+    val mutable _output_filename = ""
     val mutable _dir = ""
     val mutable _format = ""
     val mutable _basename = ""
 
     val mutable _dir = ""
     val mutable _format = ""
     val mutable _basename = ""
 
-    method private grab_format : string -> string = 
-      fun (path : string) ->
-       let fragments = Str.split (Str.regexp "\.") path in
-       let n = List.length fragments in
-       List.nth fragments (n - 1)      
-
     method read : string list -> beam_type = 
       fun (paths : string list) ->
     method read : string list -> beam_type = 
       fun (paths : string list) ->
-       let formats = List.map self#grab_format paths in
+       let formats = List.map format_of_file paths in
        let read_one : string -> string -> beam_type = 
          fun (format : string) ->
            fun (path : string) ->
        let read_one : string -> string -> beam_type = 
          fun (format : string) ->
            fun (path : string) ->
@@ -175,25 +185,34 @@ class iomanager =
          fun b1 -> fun b2 -> b1#append b2 in
        List.fold_left concat (new beam [||]) beams
              
          fun b1 -> fun b2 -> b1#append b2 in
        List.fold_left concat (new beam [||]) beams
              
-    method set : string -> string -> string -> unit = 
-      fun (dir : string) ->
-       fun (format : string) ->
-         fun (basename : string) ->
-           _dir <- dir; 
-           _format <- format; 
-           _basename <- basename;
-           wave#set _dir _basename;
-           csv#set _dir _basename
+    method set : string -> string -> string -> string -> unit = 
+      fun (filename : string) ->
+       fun (dir : string) ->
+         fun (format : string) ->
+           fun (basename : string) ->
+             _output_filename <- filename;
+             _dir <- dir; 
+             _format <- format; 
+             _basename <- basename;
+             wave#set _output_filename _dir _basename;
+             csv#set _output_filename _dir _basename
 
     method write : rate array -> data -> string array = 
       fun (rates : rate array) ->
        fun (data : data) ->
 
     method write : rate array -> data -> string array = 
       fun (rates : rate array) ->
        fun (data : data) ->
-         if _format = "" then
-           raise (Invalid_argument "output format unset.")
-         else if _format = "wav" then 
-           wave#write rates data
-         else if _format = "csv" then
-           csv#write rates data 
-         else raise (Invalid_argument "unknown format.")
-       
+         if _output_filename = "" then (
+           if _format = "" then
+             raise (Invalid_argument "output format unset.")
+           else if _format = "wav" then 
+             wave#write rates data
+           else if _format = "csv" then
+             csv#write rates data 
+           else raise (Invalid_argument "unknown format."))
+         else (
+           let format = format_of_file _output_filename in
+           if format = "wav" then
+             wave#write rates data
+           else if format = "csv" then
+             csv#write rates data
+           else raise (Invalid_argument ("unknown format" ^ format)))
   end;;
   end;;