diff --git a/bmap.go b/bmap.go index ee6459e..e25ead4 100644 --- a/bmap.go +++ b/bmap.go @@ -12,11 +12,15 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "os" + "regexp" "strings" "syscall" "github.com/frostschutz/go-fibmap" + //"github.com/uli-go/xz/lzma" + "code.google.com/p/lzma" ) const defaultBMAPHash = "0000000000000000000000000000000000000000" @@ -33,7 +37,7 @@ type BlockRange struct { type Range struct { XMLName xml.Name `xml:"Range"` Range string `xml:",chardata"` - Hash string `xml:"sha1,attr"` + Hash string `xml:"sha1,attr,omitempty"` } // MarshalXML is used to convert BlockRange to a Range object for XML output @@ -108,7 +112,7 @@ type BMap struct { } // NewBMap creates a new BMap representation by reading an image file -func NewBMap(filename string) (BMap, error) { +func NewBMap(filename string, addChecksum bool) (BMap, error) { stat := syscall.Stat_t{} syscall.Stat(filename, &stat) blockSize := stat.Blksize @@ -119,7 +123,7 @@ func NewBMap(filename string) (BMap, error) { return BMap{}, err } - blockMap, total := getBlockMap(fd, blockSize) + blockMap, total := getBlockMap(fd, blockSize, addChecksum) return BMap{Version: "1.3", ImageSize: stat.Size, BlockSize: blockSize, BlocksCount: stat.Size / blockSize, MappedBlocksCount: total, BmapFileSHA1: defaultBMAPHash, BlockMap: BlockMap{blockMap}}, nil } @@ -157,8 +161,8 @@ func (b *BMap) Write(outputFilename string) error { return err } -// Load reads and converts a bmap file to a BMap object -func Load(filename string) (BMap, error) { +// LoadFromFile reads and converts a bmap file to a BMap object +func LoadFromFile(filename string, verify bool) (BMap, error) { b := BMap{} data, err := ioutil.ReadFile(filename) if err != nil { @@ -170,30 +174,126 @@ func Load(filename string) (BMap, error) { return b, err } - // Verify SHA Hash + if verify { + // Verify SHA Hash - bmapHash := b.BmapFileSHA1 + bmapHash := b.BmapFileSHA1 - b.BmapFileSHA1 = defaultBMAPHash - output, err := b.XMLOutput() - if err != nil { - //fmt.Printf("error: %v\n", err) - return b, err + b.BmapFileSHA1 = defaultBMAPHash + output, err := b.XMLOutput() + if err != nil { + //fmt.Printf("error: %v\n", err) + return b, err + } + + if bmapHash != getSHA1Hash(bytes.NewBuffer(output)) { + return b, errors.New("XML SHA Signature does not match") + } + + b.BmapFileSHA1 = bmapHash } - if bmapHash != getSHA1Hash(bytes.NewBuffer(output)) { - return b, errors.New("XML SHA Signature does not match") - } - - b.BmapFileSHA1 = bmapHash - return b, err } -// Copy copies an input file and uses the BMap data to create a new image file -func (b *BMap) Copy(input string, output string) error { +// LoadFromReader reads and converts a bmap in an io.Reader to a BMap object +func LoadFromReader(r io.Reader, verify bool) (BMap, error) { + b := BMap{} + data, err := ioutil.ReadAll(r) + if err != nil { + return b, err + } + + err = xml.Unmarshal(data, &b) + if err != nil { + return b, err + } + + if verify { + // Verify SHA Hash + + bmapHash := b.BmapFileSHA1 + + b.BmapFileSHA1 = defaultBMAPHash + output, err := b.XMLOutput() + if err != nil { + //fmt.Printf("error: %v\n", err) + return b, err + } + + if bmapHash != getSHA1Hash(bytes.NewBuffer(output)) { + return b, errors.New("XML SHA Signature does not match") + } + + b.BmapFileSHA1 = bmapHash + } + + return b, err +} + +// loadInput is designed to handle loading local and remote files +func loadInput(input string) (io.ReadCloser, error) { + var reader io.ReadCloser + inFile, err := os.Open(input) - defer inFile.Close() + //defer inFile.Close() + reader = inFile + if err != nil { // Not local file. Try to grab remote file + switch { + case strings.HasPrefix(input, "http"): + //fmt.Println("Get URL", input) + resp, err := http.Get(input) + if err != nil { + return reader, errors.New("Unable to download image from " + input) + } + reader = resp.Body + //fmt.Println("Retrieved", resp.ContentLength, "bytes") + + } + //return reader, err + } + + return reader, nil +} + +// decompressInput "unwraps" - aka decompresses and unarchives image files +func decompressInput(input string, r io.Reader) (io.Reader, error) { + var reader io.Reader + var err error + + switch { + case strings.HasSuffix(input, ".tar.gz"), strings.HasSuffix(input, ".tgz"): + + reader, err = getTarGZReader(r) + + case strings.HasSuffix(input, ".gz"), strings.HasSuffix(input, ".gzip"): + + reader, err = getGZReader(r) + + case strings.HasSuffix(input, ".xz"), strings.HasSuffix(input, ".lzma"): + + reader, err = getLZMAReader(r) + + default: + reader = r + + } + + return reader, err +} + +// Copy copies an input file and uses the BMap data to create a new image file +func (b *BMap) Copy(input string, output string, verify bool) error { + + var reader io.Reader + + r, err := loadInput(input) + defer r.Close() + if err != nil { + return err + } + + reader, err = decompressInput(input, r) if err != nil { return err } @@ -203,46 +303,93 @@ func (b *BMap) Copy(input string, output string) error { if err != nil { return err } - err = outFile.Truncate(b.ImageSize) + + block, err := isBlockDevice(outFile) if err != nil { return err } - var reader io.ReadSeeker - - switch { - case strings.HasSuffix(input, ".gz"): - - reader, err = getGZReader(inFile) + if !block { // Can't truncate block devices + err = outFile.Truncate(b.ImageSize) if err != nil { return err } - - default: - reader = inFile } + place := int64(0) for _, block := range b.BlockMap.Range { - reader.Seek(block.Start*b.BlockSize, 0) - outFile.Seek(block.Start*b.BlockSize, 0) - length := (block.End - block.Start + 1) * b.BlockSize - written, err := io.CopyN(outFile, reader, length) + + diff := block.Start*b.BlockSize - place + place = block.Start * b.BlockSize + + _, err := moveReaderForward(reader, diff) if err != nil { return err - } - if written != length { - return errors.New("Unable to copy") + _, err = outFile.Seek(diff, os.SEEK_CUR) + if err != nil { + return err + } + length := (block.End - block.Start + 1) * b.BlockSize + place += length + if len(block.Hash) != 0 && verify { + //println("checking hash") + // Verify hash sum + //var checksumReader bytes.Buffer + checksumReader := io.TeeReader(reader, outFile) + h := sha1.New() + io.CopyN(h, checksumReader, length) + if block.Hash != hex.EncodeToString(h.Sum(nil)) { + return fmt.Errorf("Checksum mismatch for blockrange %d-%d", block.Start, block.End) + } + } else { + written, err := io.CopyN(outFile, reader, length) + if err != nil { + return err + } + if written != length { + return errors.New("Unable to copy") + + } } } + outFile.Sync() + return nil +} + +// Copy is used to copy images to a destination when a bmap file is unavailable +func Copy(input string, output string) error { + + reader, err := loadInput(input) + defer reader.Close() + if err != nil { + return nil + } + + seeker, err := decompressInput(input, reader) + if err != nil { + return err + } + + outFile, err := os.Create(output) + defer outFile.Close() + if err != nil { + return err + } + + io.Copy(outFile, seeker) + + outFile.Sync() + return nil } func getTarReader(reader io.Reader) (io.ReadSeeker, error) { tarReader := tar.NewReader(reader) + tarReader.Next() data, err := ioutil.ReadAll(tarReader) if err != nil { @@ -252,45 +399,56 @@ func getTarReader(reader io.Reader) (io.ReadSeeker, error) { } -func getTarGZReader(reader io.Reader) (io.ReadSeeker, error) { +func getTarGZReader(reader io.Reader) (io.Reader, error) { gzReader, err := gzip.NewReader(reader) if err != nil { return nil, err } tarReader := tar.NewReader(gzReader) + tarReader.Next() - data, err := ioutil.ReadAll(tarReader) - if err != nil { - return nil, err - } - return bytes.NewReader(data), nil + return tarReader, nil } -func getGZReader(reader io.Reader) (io.ReadSeeker, error) { - gzReader, err := gzip.NewReader(reader) - if err != nil { - return nil, err - } - data, err := ioutil.ReadAll(gzReader) - if err != nil { - return nil, err - } - return bytes.NewReader(data), nil +func getGZReader(reader io.Reader) (io.Reader, error) { + return gzip.NewReader(reader) } -func getBZReader(reader io.Reader) (io.ReadSeeker, error) { +func getTarBZReader(reader io.Reader) (io.Reader, error) { bzReader := bzip2.NewReader(reader) - data, err := ioutil.ReadAll(bzReader) - if err != nil { - return nil, err - } - return bytes.NewReader(data), nil + tarReader := tar.NewReader(bzReader) + tarReader.Next() + + return tarReader, nil + +} + +func getBZReader(reader io.Reader) (io.Reader, error) { + bzReader := bzip2.NewReader(reader) + + return bzReader, nil +} + +func getLZMAReader(reader io.Reader) (io.Reader, error) { + + /* + return lzma.NewReader(reader) + + */ + xzReader := lzma.NewReader(reader) + + return xzReader, nil +} + +func getXZReader(reader io.Reader) (io.ReadSeeker, error) { + //return xz.NewSeekReader(reader) + return nil, nil } // getBlockMap finds all of the ranges of written blocks in a file/image -func getBlockMap(fd *os.File, blockSize int64) ([]BlockRange, int64) { +func getBlockMap(fd *os.File, blockSize int64, addChecksum bool) ([]BlockRange, int64) { //fd, _ := os.Open(filename) f := fibmap.NewFibmapFile(fd) @@ -305,10 +463,12 @@ func getBlockMap(fd *os.File, blockSize int64) ([]BlockRange, int64) { length := v / blockSize currentBlockRange.End = currentBlockRange.Start + length - 1 - h := sha1.New() - fd.Seek(currentBlockRange.Start*blockSize, 0) - io.CopyN(h, fd, v) - currentBlockRange.Hash = hex.EncodeToString(h.Sum(nil)) + if addChecksum { + h := sha1.New() + fd.Seek(currentBlockRange.Start*blockSize, 0) + io.CopyN(h, fd, v) + currentBlockRange.Hash = hex.EncodeToString(h.Sum(nil)) + } blockMap[i/2] = currentBlockRange mappedBlocks += length @@ -323,3 +483,43 @@ func getSHA1Hash(r io.Reader) string { io.Copy(h, r) return hex.EncodeToString(h.Sum(nil)) } + +// moveReaderForward provides seek capabilities for an io.Reader, even ones that aren't an io.Seeker +func moveReaderForward(r io.Reader, count int64) (int64, error) { + + seeker, ok := r.(io.Seeker) + if ok { + return seeker.Seek(count, os.SEEK_CUR) + } else { + return io.CopyN(ioutil.Discard, r, count) + } + +} + +// isBlockDevice checks if a file descriptor is a block device +func isBlockDevice(fd *os.File) (bool, error) { + block := false + s, err := fd.Stat() + if err != nil { + return block, err + } + block = (s.Mode() & os.ModeDevice) != 0 + return block, nil +} + +// CleanBMap cleans spaces out of bmaptools bmap file and updates SHA Hash +func CleanBMap(input string) error { + data, err := ioutil.ReadFile(input) + if err != nil { + return err + } + re := regexp.MustCompile(`>\s*([\d\w-]+)\s*<`) + fixed := re.ReplaceAllString(string(data), ">${1}<") + b, err := LoadFromReader(bytes.NewReader([]byte(fixed)), false) + if err != nil { + return err + } + b.BmapFileSHA1 = defaultBMAPHash + return b.Write(input) + +} diff --git a/tool/bmap.go b/tool/bmap.go index 9e99df4..a9b2fd8 100644 --- a/tool/bmap.go +++ b/tool/bmap.go @@ -1,27 +1,124 @@ package main -import "dev.justinjudd.org/justin/bmap" +import ( + "fmt" + "path" + "strings" -const imageInputFilename = "/tmp/file.img" -const GZCompressedImageInputFilename = "/tmp/file.gz" -const bmapOutputFilename = "/tmp/test.bmap" -const imageOutputFilename = "/tmp/test.img" + "dev.justinjudd.org/justin/bmap" + "github.com/spf13/cobra" +) func main() { - /* - b := bmap.NewBMap(imageInputFilename) - b.Write("/tmp/test.bmap") - */ - - b, err := bmap.Load(bmapOutputFilename) - if err != nil { - println(err.Error()) + var rootCmd = &cobra.Command{ + Use: "bmap", + Short: "Create block map (bmap) files and copy files.", + Long: "Create block map (bmap) files and copy files.", + Run: func(cmd *cobra.Command, args []string) { + cmd.Help() + }, } - //fmt.Printf("%#v\n", b) - //err = b.Copy(imageInputFilename, imageOutputFilename) - err = b.Copy(GZCompressedImageInputFilename, imageOutputFilename) + + var output string + var noChecksum bool + + var createCmd = &cobra.Command{ + Use: "create image", + Short: "Create block map (bmap) file.", + Long: "Create block map (bmap) file.", + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + fmt.Println("Must supply image file") + return + } + create(args[0], output, !noChecksum) + }, + } + + createCmd.Flags().StringVarP(&output, "output", "o", "", "output file name") + createCmd.Flags().BoolVar(&noChecksum, "no-checksum", false, "don't generate block sum for block ranges") + + var bmapFile string + var nobmap bool + var noVerify bool + + var copyCmd = &cobra.Command{ + Use: "copy image dest", + Short: "Write an image using bmap .", + Long: "Write an image using bmap.", + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + fmt.Println("Must supply image file and destination file") + return + } + copy(args[0], args[1], bmapFile, nobmap, !noVerify) + }, + } + + copyCmd.Flags().StringVar(&bmapFile, "bmap", "", "block map file for the image") + copyCmd.Flags().BoolVar(&nobmap, "nobmap", false, "allow copying withut a bmap") + copyCmd.Flags().BoolVar(&noVerify, "no-verify", false, "do not verify the checksum of data before writing.") + + rootCmd.AddCommand(createCmd, copyCmd) + + var cleanCmd = &cobra.Command{ + Use: "clean input", + Short: "clean a bmap file.", + Long: "clean a bmap file.", + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + fmt.Println("Must supply bmap file") + return + } + clean(args[0]) + }, + } + rootCmd.AddCommand(cleanCmd) + + rootCmd.Execute() + +} + +func create(image string, output string, noChecksum bool) { + b, err := bmap.NewBMap(image, noChecksum) if err != nil { - println(err.Error()) + println("Error reading image file:", err.Error()) + } + err = b.Write(output) + if err != nil { + println("Error writing bmap file", err.Error()) + } +} + +func copy(image string, dest string, bmapFile string, nobmap bool, noVerify bool) { + if !nobmap { + if len(bmapFile) == 0 { + bmapFile = strings.TrimSuffix(image, path.Ext(image)) + ".bmap" + } + b, err := bmap.LoadFromFile(bmapFile, !noVerify) + if err != nil { + fmt.Println("Unable to read bmap file: ", bmapFile) + return + } + err = b.Copy(image, dest, noVerify) + if err != nil { + fmt.Println("Error copying image: ", err.Error()) + return + } + } else { + err := bmap.Copy(image, dest) + if err != nil { + fmt.Println("Error copying image: ", err.Error()) + return + } + } + +} + +func clean(input string) { + err := bmap.CleanBMap(input) + if err != nil { + fmt.Println(err.Error()) } }